From 2426799113b8b916d1e2ea8c4e7ab8265d15ca87 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Tue, 9 Jan 2024 01:02:01 +0100 Subject: [PATCH 01/75] CheckIn - created DownSubBlocks --- src/diffusers/models/controlnet_xs.py | 1153 +++++++++++++++++++++++++ 1 file changed, 1153 insertions(+) create mode 100644 src/diffusers/models/controlnet_xs.py diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py new file mode 100644 index 000000000000..cf1e5c7d2f33 --- /dev/null +++ b/src/diffusers/models/controlnet_xs.py @@ -0,0 +1,1153 @@ +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from torch.nn.modules.normalization import GroupNorm + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging, is_torch_version +from .attention_processor import ( + AttentionProcessor, +) +from .autoencoders import AutoencoderKL +from .lora import LoRACompatibleConv +from .embeddings import ( + TimestepEmbedding, + Timesteps, +) +from .modeling_utils import ModelMixin +from .unet_2d_blocks import ( + CrossAttnDownBlock2D, + DownBlock2D, + Downsample2D, + ResnetBlock2D, + Transformer2DModel, + UNetMidBlock2DCrossAttn, +) +from .unet_2d_condition import UNet2DConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetXSOutput(BaseOutput): + """ + The output of [`ControlNetXSModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The output of the `ControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model + output, but is already the final output. + """ + + sample: torch.FloatTensor = None + + +# copied from diffusers.models.controlnet.ControlNetConditioningEmbedding +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class ControlNetXSAddon(ModelMixin, ConfigMixin): + @classmethod + def init_original(cls, sd_type): + kwargs = {} + if sd_type == "sdxl": + kwargs.update({ + 'addition_embed_type': "text_time", + 'addition_time_embed_dim': 256, + 'attention_head_dim': [5, 10, 20], + 'block_out_channels': [320, 640, 1280], + 'cross_attention_dim': 2048, + 'down_block_types': ['DownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D'], + 'projection_class_embeddings_input_dim': 2816, + 'sample_size': 128, + 'transformer_layers_per_block': [1, 2, 10], + 'up_block_types': ['CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'UpBlock2D'], + 'upcast_attention': None, + }) + elif sd_type == "sd": + kwargs.update({ + 'addition_embed_type': None, + 'addition_time_embed_dim': None, + 'attention_head_dim': [5, 10, 20, 20], + 'block_out_channels': [320, 640, 1280, 1280], + 'cross_attention_dim': 1024, + 'down_block_types': ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'], + 'projection_class_embeddings_input_dim': None, + 'sample_size': 96, + 'transformer_layers_per_block': 1, + 'up_block_types': ['UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'], + 'upcast_attention': True + }) + else: + raise ValueError("`sd_type` needs to either 'sd' or 'sdxl'") + + return ControlNetXSAddon(**kwargs) + + @register_to_config + def __init__( + self, + channels_from_base_model: List[int], + time_embedding_input_dim: int = 320, + time_embedding_dim: int = 1280, + time_embedding_mix: float = 1.0, + learn_embedding: bool = False, + base_model_channel_sizes: Dict[str, List[Tuple[int]]] = { + "down": [ + (4, 320), + (320, 320), + (320, 320), + (320, 320), + (320, 640), + (640, 640), + (640, 640), + (640, 1280), + (1280, 1280), + ], + "mid": [(1280, 1280)], + "up": [ + (2560, 1280), + (2560, 1280), + (1920, 1280), + (1920, 640), + (1280, 640), + (960, 640), + (960, 320), + (640, 320), + (640, 320), + ], + }, + addition_embed_type = None, + addition_time_embed_dim = None, + attention_head_dim = [5, 10, 20, 20], + block_out_channels = [320, 640, 1280, 1280], + cross_attention_dim = 1024, + down_block_types = ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'], + projection_class_embeddings_input_dim = None, + sample_size = 96, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + upcast_attention = True, + ): + super().__init__() + + # todo: + # replace model surgery + # - 2.2 Allow for information infusion from base model + # - 2.3 Make group norms work with modified channel sizes + # add connections + + self.sample_size = sample_size + + # `num_attention_heads` defaults to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = attention_head_dim + + # Check inputs + # todo + + # input + self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1) + + # time + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0) + + # note umer: here `time_embedding_input_dim` is used, so time info can be received from base model + self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embed_dim) + + self.encoder_hid_proj = None + + # class embedding + self.class_embedding = None + + if addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + self.time_embed_act = None + + self.down_subblocks = nn.ModuleList([]) + self.up_subblocks = nn.ModuleList([]) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + use_crossattention = down_block_type == "CrossAttnDownBlock2D" + + self.down_subblocks.append(DownSubBlock2D( + has_resnet=True, + has_crossattn=use_crossattention, + in_channels=input_channel + 0, # todo add channels from base model + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block[i], + num_attention_heads=num_attention_heads[i], + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention + )) + self.down_subblocks.append(DownSubBlock2D( + has_resnet=True, + has_crossattn=use_crossattention, + in_channels=output_channel + 0, # todo add channels from base model + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block[i], + num_attention_heads=num_attention_heads[i], + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention + )) + self.down_subblocks.append(DownSubBlock2D( + has_downsampler=True, + in_channels=output_channel + 0, # todo add channels from base model + out_channels=output_channel, + )) + + # mid + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=0.0, + resnet_eps=1e-05, + resnet_act_fn="silu", + output_scale_factor=1, + resnet_time_scale_shift="default", + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=32, + dual_cross_attention=False, + use_linear_projection=True, + upcast_attention=upcast_attention, + attention_type="default", + ) + + # todo: connections + # 3 - Gather Channel Sizes + self.ch_inout_ctrl = ControlNetXSModel._gather_subblock_sizes(self.control_model, base_or_control="control") + self.ch_inout_base = base_model_channel_sizes + + # 4 - Build connections between base and control model + self.down_zero_convs_out = nn.ModuleList([]) + self.down_zero_convs_in = nn.ModuleList([]) + self.middle_block_out = nn.ModuleList([]) + self.middle_block_in = nn.ModuleList([]) + self.up_zero_convs_out = nn.ModuleList([]) + self.up_zero_convs_in = nn.ModuleList([]) + + for ch_io_base in self.ch_inout_base["down"]: + self.down_zero_convs_in.append(self._make_zero_conv(in_channels=ch_io_base[1], out_channels=ch_io_base[1])) + for i in range(len(self.ch_inout_ctrl["down"])): + self.down_zero_convs_out.append( + self._make_zero_conv(self.ch_inout_ctrl["down"][i][1], self.ch_inout_base["down"][i][1]) + ) + + self.middle_block_out = self._make_zero_conv( + self.ch_inout_ctrl["mid"][-1][1], self.ch_inout_base["mid"][-1][1] + ) + + self.up_zero_convs_out.append( + self._make_zero_conv(self.ch_inout_ctrl["down"][-1][1], self.ch_inout_base["mid"][-1][1]) + ) + for i in range(1, len(self.ch_inout_ctrl["down"])): + self.up_zero_convs_out.append( + self._make_zero_conv(self.ch_inout_ctrl["down"][-(i + 1)][1], self.ch_inout_base["up"][i - 1][1]) + ) + + + def forward(self, sample, encoder_hidden_states, added_cond_kwargs = {}): + #raise ValueError("A ControlNetXSAddonModel cannot be run by itself. Pass it into a ControlNetXSModel model instead.") + + timestep = 980 + cross_attention_kwargs = {} + timestep_cond = None + + # # # unet.forward for testing + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) + encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=1.0) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = self.mid_block(sample, emb) + + return sample + + +class ControlNetXSModel(ModelMixin, ConfigMixin): + r""" + A ControlNet-XS model + + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic + methods implemented for all models (such as downloading or saving). + + Most of parameters for this model are passed into the [`UNet2DConditionModel`] it creates. Check the documentation + of [`UNet2DConditionModel`] for them. + + Parameters: + conditioning_channels (`int`, defaults to 3): + Number of channels of conditioning input (e.g. an image) + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `controlnet_cond_embedding` layer. + time_embedding_input_dim (`int`, defaults to 320): + Dimension of input into time embedding. Needs to be same as in the base model. + time_embedding_dim (`int`, defaults to 1280): + Dimension of output from time embedding. Needs to be same as in the base model. + learn_embedding (`bool`, defaults to `False`): + Whether to use time embedding of the control model. If yes, the time embedding is a linear interpolation of + the time embeddings of the control and base model with interpolation parameter `time_embedding_mix**3`. + time_embedding_mix (`float`, defaults to 1.0): + Linear interpolation parameter used if `learn_embedding` is `True`. A value of 1.0 means only the + control model's time embedding will be used. A value of 0.0 means only the base model's time embedding will be used. + base_model_channel_sizes (`Dict[str, List[Tuple[int]]]`): + Channel sizes of each subblock of base model. Use `gather_subblock_sizes` on your base model to compute it. + """ + + @classmethod + def init_original(cls, base_model: UNet2DConditionModel, is_sdxl=True): + """ + Create a ControlNetXS model with the same parameters as in the original paper (https://github.com/vislearn/ControlNet-XS). + + Parameters: + base_model (`UNet2DConditionModel`): + Base UNet model. Needs to be either StableDiffusion or StableDiffusion-XL. + is_sdxl (`bool`, defaults to `True`): + Whether passed `base_model` is a StableDiffusion-XL model. + """ + + def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_attn_heads: int): + """ + Currently, diffusers can only set the dimension of attention heads (see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why). + The original ControlNet-XS model, however, define the number of attention heads. + That's why compute the dimensions needed to get the correct number of attention heads. + """ + block_out_channels = [int(size_ratio * c) for c in base_model.config.block_out_channels] + dim_attn_heads = [math.ceil(c / num_attn_heads) for c in block_out_channels] + return dim_attn_heads + + if is_sdxl: + return ControlNetXSModel.from_unet( + base_model, + time_embedding_mix=0.95, + learn_embedding=True, + size_ratio=0.1, + conditioning_embedding_out_channels=(16, 32, 96, 256), + num_attention_heads=get_dim_attn_heads(base_model, 0.1, 64), + ) + else: + return ControlNetXSModel.from_unet( + base_model, + time_embedding_mix=1.0, + learn_embedding=True, + size_ratio=0.0125, + conditioning_embedding_out_channels=(16, 32, 96, 256), + num_attention_heads=get_dim_attn_heads(base_model, 0.0125, 8), + ) + + @classmethod + def _gather_subblock_sizes(cls, unet: UNet2DConditionModel, base_or_control: str): + """To create correctly sized connections between base and control model, we need to know + the input and output channels of each subblock. + + Parameters: + unet (`UNet2DConditionModel`): + Unet of which the subblock channels sizes are to be gathered. + base_or_control (`str`): + Needs to be either "base" or "control". If "base", decoder is also considered. + """ + if base_or_control not in ["base", "control"]: + raise ValueError("`base_or_control` needs to be either `base` or `control`") + + channel_sizes = {"down": [], "mid": [], "up": []} + + # input convolution + channel_sizes["down"].append((unet.conv_in.in_channels, unet.conv_in.out_channels)) + + # encoder blocks + for module in unet.down_blocks: + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + for r in module.resnets: + channel_sizes["down"].append((r.in_channels, r.out_channels)) + if module.downsamplers: + channel_sizes["down"].append( + (module.downsamplers[0].channels, module.downsamplers[0].out_channels) + ) + else: + raise ValueError(f"Encountered unknown module of type {type(module)} while creating ControlNet-XS.") + + # middle block + channel_sizes["mid"].append((unet.mid_block.resnets[0].in_channels, unet.mid_block.resnets[0].out_channels)) + + # decoder blocks + if base_or_control == "base": + for module in unet.up_blocks: + if isinstance(module, (CrossAttnUpBlock2D, UpBlock2D)): + for r in module.resnets: + channel_sizes["up"].append((r.in_channels, r.out_channels)) + else: + raise ValueError( + f"Encountered unknown module of type {type(module)} while creating ControlNet-XS." + ) + + return channel_sizes + + @register_to_config + def __init__( + self, + conditioning_channels: int = 3, + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + controlnet_conditioning_channel_order: str = "rgb", + time_embedding_input_dim: int = 320, + time_embedding_dim: int = 1280, + time_embedding_mix: float = 1.0, + learn_embedding: bool = False, + base_model_channel_sizes: Dict[str, List[Tuple[int]]] = { + "down": [ + (4, 320), + (320, 320), + (320, 320), + (320, 320), + (320, 640), + (640, 640), + (640, 640), + (640, 1280), + (1280, 1280), + ], + "mid": [(1280, 1280)], + "up": [ + (2560, 1280), + (2560, 1280), + (1920, 1280), + (1920, 640), + (1280, 640), + (960, 640), + (960, 320), + (640, 320), + (640, 320), + ], + }, + sample_size: Optional[int] = None, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + norm_num_groups: Optional[int] = 32, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, + upcast_attention: bool = False, + ): + super().__init__() + + # 1 - Create control unet + self.control_model = UNet2DConditionModel( + sample_size=sample_size, + down_block_types=down_block_types, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + attention_head_dim=num_attention_heads, + use_linear_projection=True, + upcast_attention=upcast_attention, + time_embedding_dim=time_embedding_dim, + ) + + # 5 - Create conditioning hint embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + conditioning_channels: int = 3, + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + controlnet_conditioning_channel_order: str = "rgb", + learn_embedding: bool = False, + time_embedding_mix: float = 1.0, + block_out_channels: Optional[Tuple[int]] = None, + size_ratio: Optional[float] = None, + num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, + norm_num_groups: Optional[int] = None, + ): + r""" + Instantiate a [`ControlNetXSModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it. + conditioning_channels (`int`, defaults to 3): + Number of channels of conditioning input (e.g. an image) + conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `controlnet_cond_embedding` layer. + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + learn_embedding (`bool`, defaults to `False`): + Wether to use time embedding of the control model. If yes, the time embedding is a linear interpolation + of the time embeddings of the control and base model with interpolation parameter + `time_embedding_mix**3`. + time_embedding_mix (`float`, defaults to 1.0): + Linear interpolation parameter used if `learn_embedding` is `True`. + block_out_channels (`Tuple[int]`, *optional*): + Down blocks output channels in control model. Either this or `size_ratio` must be given. + size_ratio (float, *optional*): + When given, block_out_channels is set to a relative fraction of the base model's block_out_channels. + Either this or `block_out_channels` must be given. + num_attention_heads (`Union[int, Tuple[int]]`, *optional*): + The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. + norm_num_groups (int, *optional*, defaults to `None`): + The number of groups to use for the normalization of the control unet. If `None`, + `int(unet.config.norm_num_groups * size_ratio)` is taken. + """ + + # Check input + fixed_size = block_out_channels is not None + relative_size = size_ratio is not None + if not (fixed_size ^ relative_size): + raise ValueError( + "Pass exactly one of `block_out_channels` (for absolute sizing) or `control_model_ratio` (for relative sizing)." + ) + + # Create model + if block_out_channels is None: + block_out_channels = [int(size_ratio * c) for c in unet.config.block_out_channels] + + # Check that attention heads and group norms match channel sizes + # - attention heads + def attn_heads_match_channel_sizes(attn_heads, channel_sizes): + if isinstance(attn_heads, (tuple, list)): + return all(c % a == 0 for a, c in zip(attn_heads, channel_sizes)) + else: + return all(c % attn_heads == 0 for c in channel_sizes) + + num_attention_heads = num_attention_heads or unet.config.attention_head_dim + if not attn_heads_match_channel_sizes(num_attention_heads, block_out_channels): + raise ValueError( + f"The dimension of attention heads ({num_attention_heads}) must divide `block_out_channels` ({block_out_channels}). If you didn't set `num_attention_heads` the default settings don't match your model. Set `num_attention_heads` manually." + ) + + # - group norms + def group_norms_match_channel_sizes(num_groups, channel_sizes): + return all(c % num_groups == 0 for c in channel_sizes) + + if norm_num_groups is None: + if group_norms_match_channel_sizes(unet.config.norm_num_groups, block_out_channels): + norm_num_groups = unet.config.norm_num_groups + else: + norm_num_groups = min(block_out_channels) + + if group_norms_match_channel_sizes(norm_num_groups, block_out_channels): + print( + f"`norm_num_groups` was set to `min(block_out_channels)` (={norm_num_groups}) so it divides all block_out_channels` ({block_out_channels}). Set it explicitly to remove this information." + ) + else: + raise ValueError( + f"`block_out_channels` ({block_out_channels}) don't match the base models `norm_num_groups` ({unet.config.norm_num_groups}). Setting `norm_num_groups` to `min(block_out_channels)` ({norm_num_groups}) didn't fix this. Pass `norm_num_groups` explicitly so it divides all block_out_channels." + ) + + def get_time_emb_input_dim(unet: UNet2DConditionModel): + return unet.time_embedding.linear_1.in_features + + def get_time_emb_dim(unet: UNet2DConditionModel): + return unet.time_embedding.linear_2.out_features + + # Clone params from base unet if + # (i) it's required to build SD or SDXL, and + # (ii) it's not used for the time embedding (as time embedding of control model is never used), and + # (iii) it's not set further below anyway + to_keep = [ + "cross_attention_dim", + "down_block_types", + "sample_size", + "transformer_layers_per_block", + "up_block_types", + "upcast_attention", + ] + kwargs = {k: v for k, v in dict(unet.config).items() if k in to_keep} + kwargs.update(block_out_channels=block_out_channels) + kwargs.update(num_attention_heads=num_attention_heads) + kwargs.update(norm_num_groups=norm_num_groups) + + # Add controlnetxs-specific params + kwargs.update( + conditioning_channels=conditioning_channels, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + time_embedding_input_dim=get_time_emb_input_dim(unet), + time_embedding_dim=get_time_emb_dim(unet), + time_embedding_mix=time_embedding_mix, + learn_embedding=learn_embedding, + base_model_channel_sizes=ControlNetXSModel._gather_subblock_sizes(unet, base_or_control="base"), + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + ) + + return cls(**kwargs) + + def forward( + self, + base_model: UNet2DConditionModel, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + return_dict: bool = True, + ) -> Union[ControlNetXSOutput, Tuple]: + """ + The [`ControlNetModel`] forward method. + + Args: + base_model (`UNet2DConditionModel`): + The base unet model we want to control. + sample (`torch.FloatTensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.FloatTensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + How much the control model affects the base model outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + + Returns: + [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # scale control strength + n_connections = len(self.down_zero_convs_out) + 1 + len(self.up_zero_convs_out) + scale_list = torch.full((n_connections,), conditioning_scale) + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = base_model.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + if self.config.learn_embedding: + ctrl_temb = self.control_model.time_embedding(t_emb, timestep_cond) + base_temb = base_model.time_embedding(t_emb, timestep_cond) + interpolation_param = self.config.time_embedding_mix**0.3 + + temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) + else: + temb = base_model.time_embedding(t_emb) + + # added time & text embeddings + aug_emb = None + + if base_model.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if base_model.config.class_embed_type == "timestep": + class_labels = base_model.time_proj(class_labels) + + class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype) + temb = temb + class_emb + + if base_model.config.addition_embed_type is not None: + if base_model.config.addition_embed_type == "text": + aug_emb = base_model.add_embedding(encoder_hidden_states) + elif base_model.config.addition_embed_type == "text_image": + raise NotImplementedError() + elif base_model.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = base_model.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(temb.dtype) + aug_emb = base_model.add_embedding(add_embeds) + elif base_model.config.addition_embed_type == "image": + raise NotImplementedError() + elif base_model.config.addition_embed_type == "image_hint": + raise NotImplementedError() + + temb = temb + aug_emb if aug_emb is not None else temb + + # text embeddings + cemb = encoder_hidden_states + + # Preparation + guided_hint = self.controlnet_cond_embedding(controlnet_cond) + + h_ctrl = h_base = sample + hs_base, hs_ctrl = [], [] + it_down_convs_in, it_down_convs_out, it_dec_convs_in, it_up_convs_out = map( + iter, (self.down_zero_convs_in, self.down_zero_convs_out, self.up_zero_convs_in, self.up_zero_convs_out) + ) + scales = iter(scale_list) + + base_down_subblocks = to_sub_blocks(base_model.down_blocks) + ctrl_down_subblocks = to_sub_blocks(self.control_model.down_blocks) + base_mid_subblocks = to_sub_blocks([base_model.mid_block]) + ctrl_mid_subblocks = to_sub_blocks([self.control_model.mid_block]) + base_up_subblocks = to_sub_blocks(base_model.up_blocks) + + # Cross Control + # 0 - conv in + h_base = base_model.conv_in(h_base) + h_ctrl = self.control_model.conv_in(h_ctrl) + if guided_hint is not None: + h_ctrl += guided_hint + h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base + + hs_base.append(h_base) + hs_ctrl.append(h_ctrl) + + # 1 - down + for m_base, m_ctrl in zip(base_down_subblocks, ctrl_down_subblocks): + h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl + h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock + h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock + h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base + hs_base.append(h_base) + hs_ctrl.append(h_ctrl) + + # 2 - mid + h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl + for m_base, m_ctrl in zip(base_mid_subblocks, ctrl_mid_subblocks): + h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock + h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock + h_base = h_base + self.middle_block_out(h_ctrl) * next(scales) # D - add ctrl -> base + + # 3 - up + for i, m_base in enumerate(base_up_subblocks): + h_base = h_base + next(it_up_convs_out)(hs_ctrl.pop()) * next(scales) # add info from ctrl encoder + h_base = torch.cat([h_base, hs_base.pop()], dim=1) # concat info from base encoder+ctrl encoder + h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) + + h_base = base_model.conv_norm_out(h_base) + h_base = base_model.conv_act(h_base) + h_base = base_model.conv_out(h_base) + + if not return_dict: + return h_base + + return ControlNetXSOutput(sample=h_base) + + def _make_zero_conv(self, in_channels, out_channels=None): + # keep running track of channels sizes + self.in_channels = in_channels + self.out_channels = out_channels or in_channels + + return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) + + +def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by): + """Increase channels sizes to allow for additional concatted information from base model""" + m = unet.mid_block.resnets[0] + old_norm1, old_conv1 = m.norm1, m.conv1 + # norm + norm_args = "num_groups num_channels eps affine".split(" ") + for a in norm_args: + assert hasattr(old_norm1, a) + norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args} + norm_kwargs["num_channels"] += by # surgery done here + # conv1 + conv1_args = ( + "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ") + ) + for a in conv1_args: + assert hasattr(old_conv1, a) + conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args} + conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor. + conv1_kwargs["in_channels"] += by # surgery done here + # conv_shortcut + # as we changed the input size of the block, the input and output sizes are likely different, + # therefore we need a conv_shortcut (simply adding won't work) + conv_shortcut_args_kwargs = { + "in_channels": conv1_kwargs["in_channels"], + "out_channels": conv1_kwargs["out_channels"], + # default arguments from resnet.__init__ + "kernel_size": 1, + "stride": 1, + "padding": 0, + "bias": True, + } + # swap old with new modules + unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs) + unet.mid_block.resnets[0].conv1 = LoRACompatibleConv(**conv1_kwargs) + unet.mid_block.resnets[0].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs) + unet.mid_block.resnets[0].in_channels += by # surgery done here + + +def adjust_group_norms(unet: UNet2DConditionModel, max_num_group: int = 32): + def find_denominator(number, start): + if start >= number: + return number + while start != 0: + residual = number % start + if residual == 0: + return start + start -= 1 + + for block in [*unet.down_blocks, unet.mid_block]: + # resnets + for r in block.resnets: + if r.norm1.num_groups < max_num_group: + r.norm1.num_groups = find_denominator(r.norm1.num_channels, start=max_num_group) + + if r.norm2.num_groups < max_num_group: + r.norm2.num_groups = find_denominator(r.norm2.num_channels, start=max_num_group) + + # transformers + if hasattr(block, "attentions"): + for a in block.attentions: + if a.norm.num_groups < max_num_group: + a.norm.num_groups = find_denominator(a.norm.num_channels, start=max_num_group) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + + +class DownSubBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: Optional[int] = None, + transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, + num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + upcast_attention: Optional[bool] = False, + has_resnet = False, + has_crossattn = False, + has_downsampler = False, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if has_resnet: + self.resnet = ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + else: + self.resnet = None + + if has_crossattn: + self.attention = Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + use_linear_projection=True, + upcast_attention=upcast_attention, + ) + else: + self.attention = None + + if has_downsampler: + self.downsampler = Downsample2D(out_channels, use_conv=True, out_channels=out_channels, name="op") + else: + self.downsampler = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + # todo + + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + if self.training and self.gradient_checkpointing: + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = self.attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = self.resnet(hidden_states, temb, scale=lora_scale) + hidden_states = self.attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + return hidden_states \ No newline at end of file From 1c36d0b721bec113cbbcc30623cad73169e358e4 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Tue, 9 Jan 2024 14:04:54 +0100 Subject: [PATCH 02/75] Added extra channels, implemented subblock fwd --- src/diffusers/models/controlnet_xs.py | 271 +++++++++----------------- 1 file changed, 95 insertions(+), 176 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index cf1e5c7d2f33..4542f8997130 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -135,9 +135,10 @@ def init_original(cls, sd_type): def __init__( self, channels_from_base_model: List[int], + conditioning_channels: int = 3, + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), time_embedding_input_dim: int = 320, time_embedding_dim: int = 1280, - time_embedding_mix: float = 1.0, learn_embedding: bool = False, base_model_channel_sizes: Dict[str, List[Tuple[int]]] = { "down": [ @@ -166,10 +167,11 @@ def __init__( }, addition_embed_type = None, addition_time_embed_dim = None, - attention_head_dim = [5, 10, 20, 20], - block_out_channels = [320, 640, 1280, 1280], + attention_head_dim = [5, 10, 20, 20], + block_out_channels = [32, 64, 128], + base_block_out_channels = [320, 640, 1280], cross_attention_dim = 1024, - down_block_types = ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'], + down_block_types = ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'], projection_class_embeddings_input_dim = None, sample_size = 96, transformer_layers_per_block: Union[int, Tuple[int]] = 1, @@ -231,6 +233,19 @@ def __init__( blocks_time_embed_dim = time_embed_dim # down + def get_extra_channel(block_no, subblock_no): + """Determine channel size for extra info from base - todo""" + if block_no==0: + # in 1st block: all same - todo + return base_block_out_channels[0] + else: + if subblock_no==0: + # in 2nd+ block: in 1st subblock, no change yet - todo + return base_block_out_channels[block_no-1] + else: + # in 2nd+ block: in 2nd+ subblock, resnet has double channels -> change - todo + return base_block_out_channels[block_no] + output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel @@ -240,7 +255,7 @@ def __init__( self.down_subblocks.append(DownSubBlock2D( has_resnet=True, has_crossattn=use_crossattention, - in_channels=input_channel + 0, # todo add channels from base model + in_channels=input_channel + get_extra_channel(block_no=i, subblock_no=0), out_channels=output_channel, temb_channels=blocks_time_embed_dim, transformer_layers_per_block=transformer_layers_per_block[i], @@ -251,7 +266,7 @@ def __init__( self.down_subblocks.append(DownSubBlock2D( has_resnet=True, has_crossattn=use_crossattention, - in_channels=output_channel + 0, # todo add channels from base model + in_channels=output_channel + get_extra_channel(block_no=i, subblock_no=1), out_channels=output_channel, temb_channels=blocks_time_embed_dim, transformer_layers_per_block=transformer_layers_per_block[i], @@ -259,11 +274,12 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention )) - self.down_subblocks.append(DownSubBlock2D( - has_downsampler=True, - in_channels=output_channel + 0, # todo add channels from base model - out_channels=output_channel, - )) + if i= number: + return number + while start != 0: + residual = number % start + if residual == 0: + return start + start -= 1 + + if norm.num_groups < max_num_group: + norm.num_groups = find_denominator(norm.num_channels, start=max_num_group) + + for subblock in self.down_subblocks: + if subblock.resnet is not None: + adjust_group_norms(subblock.resnet.norm1) + adjust_group_norms(subblock.resnet.norm2) + if subblock.attention is not None: + adjust_group_norms(subblock.attention.norm) + for resnet in self.mid_block.resnets: + adjust_group_norms(resnet.norm1) + adjust_group_norms(resnet.norm2) + for attn in self.mid_block.attentions: + adjust_group_norms(attn.norm) + # todo: connections # 3 - Gather Channel Sizes self.ch_inout_ctrl = ControlNetXSModel._gather_subblock_sizes(self.control_model, base_or_control="control") @@ -293,9 +335,7 @@ def __init__( self.down_zero_convs_out = nn.ModuleList([]) self.down_zero_convs_in = nn.ModuleList([]) self.middle_block_out = nn.ModuleList([]) - self.middle_block_in = nn.ModuleList([]) self.up_zero_convs_out = nn.ModuleList([]) - self.up_zero_convs_in = nn.ModuleList([]) for ch_io_base in self.ch_inout_base["down"]: self.down_zero_convs_in.append(self._make_zero_conv(in_channels=ch_io_base[1], out_channels=ch_io_base[1])) @@ -316,133 +356,16 @@ def __init__( self._make_zero_conv(self.ch_inout_ctrl["down"][-(i + 1)][1], self.ch_inout_base["up"][i - 1][1]) ) + # 5 - Create conditioning hint embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) - def forward(self, sample, encoder_hidden_states, added_cond_kwargs = {}): - #raise ValueError("A ControlNetXSAddonModel cannot be run by itself. Pass it into a ControlNetXSModel model instead.") - - timestep = 980 - cross_attention_kwargs = {} - timestep_cond = None - - # # # unet.forward for testing - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # `Timesteps` does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) - - emb = self.time_embedding(t_emb, timestep_cond) - aug_emb = None - - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - # `Timesteps` does not contain any weights and will always return f32 tensors - # there might be better ways to encapsulate this. - class_labels = class_labels.to(dtype=sample.dtype) - - class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) - - if self.config.class_embeddings_concat: - emb = torch.cat([emb, class_emb], dim=-1) - else: - emb = emb + class_emb - - if self.config.addition_embed_type == "text": - aug_emb = self.add_embedding(encoder_hidden_states) - elif self.config.addition_embed_type == "text_time": - # SDXL - style - if "text_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" - ) - text_embeds = added_cond_kwargs.get("text_embeds") - if "time_ids" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" - ) - time_ids = added_cond_kwargs.get("time_ids") - time_embeds = self.add_time_proj(time_ids.flatten()) - time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) - add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) - add_embeds = add_embeds.to(emb.dtype) - aug_emb = self.add_embedding(add_embeds) - - emb = emb + aug_emb if aug_emb is not None else emb - - if self.time_embed_act is not None: - emb = self.time_embed_act(emb) - - if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": - encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) - - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" - ) - image_embeds = added_cond_kwargs.get("image_embeds") - image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) - encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) - - # 2. pre-process - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - # For t2i-adapter CrossAttnDownBlock2D - additional_residuals = {} - - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - **additional_residuals, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=1.0) - - down_block_res_samples += res_samples - - # 4. mid - if self.mid_block is not None: - if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample = self.mid_block(sample, emb) - return sample + def forward(self, *args, **kwargs): + raise ValueError("A ControlNetXSAddonModel cannot be run by itself. Pass it into a ControlNetXSModel model instead.") class ControlNetXSModel(ModelMixin, ConfigMixin): @@ -567,8 +490,6 @@ def _gather_subblock_sizes(cls, unet: UNet2DConditionModel, base_or_control: str @register_to_config def __init__( self, - conditioning_channels: int = 3, - conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), controlnet_conditioning_channel_order: str = "rgb", time_embedding_input_dim: int = 320, time_embedding_dim: int = 1280, @@ -631,13 +552,6 @@ def __init__( time_embedding_dim=time_embedding_dim, ) - # 5 - Create conditioning hint embedding - self.controlnet_cond_embedding = ControlNetConditioningEmbedding( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=conditioning_embedding_out_channels, - conditioning_channels=conditioning_channels, - ) - @classmethod def from_unet( cls, @@ -1051,7 +965,6 @@ def zero_module(module): return module - class DownSubBlock2D(nn.Module): def __init__( self, @@ -1110,8 +1023,6 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - # todo - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 if self.training and self.gradient_checkpointing: @@ -1124,30 +1035,38 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = self.attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + if self.resnet is not None: + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + if self.attention is not None: + hidden_states = self.attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states) else: - hidden_states = self.resnet(hidden_states, temb, scale=lora_scale) - hidden_states = self.attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - - return hidden_states \ No newline at end of file + if self.resnet is not None: + hidden_states = self.resnet(hidden_states, temb, scale=lora_scale) + if self.attention is not None: + hidden_states = self.attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states) + + return hidden_states From 9626ce3171705ea125de8cc80c99637d5fffcafe Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 10 Jan 2024 12:39:56 +0100 Subject: [PATCH 03/75] Fixed connection sizes --- src/diffusers/models/controlnet_xs.py | 541 +++++++++----------------- 1 file changed, 194 insertions(+), 347 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 4542f8997130..771e20cd833d 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -97,6 +97,7 @@ def forward(self, conditioning): class ControlNetXSAddon(ModelMixin, ConfigMixin): @classmethod def init_original(cls, sd_type): + # todo kwargs = {} if sd_type == "sdxl": kwargs.update({ @@ -134,56 +135,33 @@ def init_original(cls, sd_type): @register_to_config def __init__( self, - channels_from_base_model: List[int], conditioning_channels: int = 3, conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), time_embedding_input_dim: int = 320, time_embedding_dim: int = 1280, learn_embedding: bool = False, base_model_channel_sizes: Dict[str, List[Tuple[int]]] = { - "down": [ - (4, 320), - (320, 320), - (320, 320), - (320, 320), - (320, 640), - (640, 640), - (640, 640), - (640, 1280), - (1280, 1280), - ], - "mid": [(1280, 1280)], - "up": [ - (2560, 1280), - (2560, 1280), - (1920, 1280), - (1920, 640), - (1280, 640), - (960, 640), - (960, 320), - (640, 320), - (640, 320), - ], + "down - in": [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280], + "down - out": [320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280], + "mid": 1280, + "up - in": [1280, 1280, 1280, 1280,1280, 1280, 1280, 640, 640, 640, 320, 320], }, addition_embed_type = None, addition_time_embed_dim = None, - attention_head_dim = [5, 10, 20, 20], - block_out_channels = [32, 64, 128], - base_block_out_channels = [320, 640, 1280], + attention_head_dim = [4], + block_out_channels = [4, 8, 16, 16], + base_block_out_channels = [320, 640, 1280, 1280], cross_attention_dim = 1024, - down_block_types = ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'], + down_block_types = ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D','CrossAttnDownBlock2D', 'DownBlock2D'], projection_class_embeddings_input_dim = None, sample_size = 96, transformer_layers_per_block: Union[int, Tuple[int]] = 1, upcast_attention = True, + norm_num_groups = 4, ): super().__init__() - # todo: - # replace model surgery - # - 2.2 Allow for information infusion from base model - # - 2.3 Make group norms work with modified channel sizes - # add connections + # todo: learn_embedding self.sample_size = sample_size @@ -202,17 +180,10 @@ def __init__( # time time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0) - # note umer: here `time_embedding_input_dim` is used, so time info can be received from base model self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embed_dim) - self.encoder_hid_proj = None - - # class embedding - self.class_embedding = None - if addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) @@ -261,7 +232,8 @@ def get_extra_channel(block_no, subblock_no): transformer_layers_per_block=transformer_layers_per_block[i], num_attention_heads=num_attention_heads[i], cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention + upcast_attention=upcast_attention, + norm_num_groups=norm_num_groups, )) self.down_subblocks.append(DownSubBlock2D( has_resnet=True, @@ -272,7 +244,8 @@ def get_extra_channel(block_no, subblock_no): transformer_layers_per_block=transformer_layers_per_block[i], num_attention_heads=num_attention_heads[i], cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention + upcast_attention=upcast_attention, + norm_num_groups=norm_num_groups, )) if i= number: - return number - while start != 0: - residual = number % start - if residual == 0: - return start - start -= 1 - - if norm.num_groups < max_num_group: - norm.num_groups = find_denominator(norm.num_channels, start=max_num_group) - - for subblock in self.down_subblocks: - if subblock.resnet is not None: - adjust_group_norms(subblock.resnet.norm1) - adjust_group_norms(subblock.resnet.norm2) - if subblock.attention is not None: - adjust_group_norms(subblock.attention.norm) - for resnet in self.mid_block.resnets: - adjust_group_norms(resnet.norm1) - adjust_group_norms(resnet.norm2) - for attn in self.mid_block.attentions: - adjust_group_norms(attn.norm) - - # todo: connections # 3 - Gather Channel Sizes - self.ch_inout_ctrl = ControlNetXSModel._gather_subblock_sizes(self.control_model, base_or_control="control") + conditioning_embedding_out_channels + self.ch_inout_ctrl = { + "down - out": [s.out_channels for s in self.down_subblocks], + "mid - out": self.down_subblocks[-1].out_channels + } self.ch_inout_base = base_model_channel_sizes # 4 - Build connections between base and control model self.down_zero_convs_out = nn.ModuleList([]) self.down_zero_convs_in = nn.ModuleList([]) - self.middle_block_out = nn.ModuleList([]) + self.middle_zero_convs_out = nn.ModuleList([]) self.up_zero_convs_out = nn.ModuleList([]) + + # 4.1 - Connections from base encoder to ctrl encoder + # Information is passed from base to ctrl _before_ each subblock. We therefore use the 'in' channels. + # As the information is concatted in ctrl, we don't need to change channel sizes. So channels in = channels out. + for c in base_model_channel_sizes['down - in']: + self.down_zero_convs_in.append(self._make_zero_conv(c, c)) + c = base_model_channel_sizes['mid'] + self.down_zero_convs_in.append(self._make_zero_conv(c, c)) + + # 4.2 - Connections from ctrl encoder to base encoder + # Information is passed from ctrl to base _after_ each subblock. We therefore use the 'out' channels. + # As the information is added to base, the out-channels need to match base. + for i in range(len(self.down_subblocks)): + ch_base_out = base_model_channel_sizes['down - out'][i] + ch_ctrl_out = self.ch_inout_ctrl['down - out'][i] + if i==0: + # for conv_in + self.down_zero_convs_out.append(self._make_zero_conv(self.conv_in.out_channels, ch_base_out)) + self.down_zero_convs_out.append(self._make_zero_conv(ch_ctrl_out, ch_base_out)) + + # 4.3 - Connections in mid block + # todo + ch_base_out = base_model_channel_sizes['mid - out'] + ch_ctrl_out = self.ch_inout_ctrl['mid - out'] + self.middle_zero_convs_out = self._make_zero_conv(ch_ctrl_out, ch_base_out) - for ch_io_base in self.ch_inout_base["down"]: - self.down_zero_convs_in.append(self._make_zero_conv(in_channels=ch_io_base[1], out_channels=ch_io_base[1])) - for i in range(len(self.ch_inout_ctrl["down"])): - self.down_zero_convs_out.append( - self._make_zero_conv(self.ch_inout_ctrl["down"][i][1], self.ch_inout_base["down"][i][1]) - ) - - self.middle_block_out = self._make_zero_conv( - self.ch_inout_ctrl["mid"][-1][1], self.ch_inout_base["mid"][-1][1] - ) - - self.up_zero_convs_out.append( - self._make_zero_conv(self.ch_inout_ctrl["down"][-1][1], self.ch_inout_base["mid"][-1][1]) - ) - for i in range(1, len(self.ch_inout_ctrl["down"])): - self.up_zero_convs_out.append( - self._make_zero_conv(self.ch_inout_ctrl["down"][-(i + 1)][1], self.ch_inout_base["up"][i - 1][1]) - ) + # 4.3 - Connections from ctrl encoder to base decoder + # todo + skip_channels = reversed([self.conv_in.out_channels] + self.ch_inout_ctrl['down - out']) + for s,i in zip(skip_channels, base_model_channel_sizes['up - in']): + self.up_zero_convs_out.append(self._make_zero_conv(s, i)) # 5 - Create conditioning hint embedding self.controlnet_cond_embedding = ControlNetConditioningEmbedding( @@ -363,10 +325,139 @@ def find_denominator(number, start): conditioning_channels=conditioning_channels, ) - def forward(self, *args, **kwargs): raise ValueError("A ControlNetXSAddonModel cannot be run by itself. Pass it into a ControlNetXSModel model instead.") + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + conditioning_channels: int = 3, + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + controlnet_conditioning_channel_order: str = "rgb", + learn_embedding: bool = False, + time_embedding_mix: float = 1.0, + block_out_channels: Optional[Tuple[int]] = None, + size_ratio: Optional[float] = None, + num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, + norm_num_groups: Optional[int] = None, + ): + # todo + r""" + Instantiate a [`ControlNetXSModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it. + conditioning_channels (`int`, defaults to 3): + Number of channels of conditioning input (e.g. an image) + conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `controlnet_cond_embedding` layer. + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + learn_embedding (`bool`, defaults to `False`): + Wether to use time embedding of the control model. If yes, the time embedding is a linear interpolation + of the time embeddings of the control and base model with interpolation parameter + `time_embedding_mix**3`. + time_embedding_mix (`float`, defaults to 1.0): + Linear interpolation parameter used if `learn_embedding` is `True`. + block_out_channels (`Tuple[int]`, *optional*): + Down blocks output channels in control model. Either this or `size_ratio` must be given. + size_ratio (float, *optional*): + When given, block_out_channels is set to a relative fraction of the base model's block_out_channels. + Either this or `block_out_channels` must be given. + num_attention_heads (`Union[int, Tuple[int]]`, *optional*): + The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. + norm_num_groups (int, *optional*, defaults to `None`): + The number of groups to use for the normalization of the control unet. If `None`, + `int(unet.config.norm_num_groups * size_ratio)` is taken. + """ + + # Check input + fixed_size = block_out_channels is not None + relative_size = size_ratio is not None + if not (fixed_size ^ relative_size): + raise ValueError( + "Pass exactly one of `block_out_channels` (for absolute sizing) or `control_model_ratio` (for relative sizing)." + ) + + # Create model + if block_out_channels is None: + block_out_channels = [int(size_ratio * c) for c in unet.config.block_out_channels] + + # Check that attention heads and group norms match channel sizes + # - attention heads + def attn_heads_match_channel_sizes(attn_heads, channel_sizes): + if isinstance(attn_heads, (tuple, list)): + return all(c % a == 0 for a, c in zip(attn_heads, channel_sizes)) + else: + return all(c % attn_heads == 0 for c in channel_sizes) + + num_attention_heads = num_attention_heads or unet.config.attention_head_dim + if not attn_heads_match_channel_sizes(num_attention_heads, block_out_channels): + raise ValueError( + f"The dimension of attention heads ({num_attention_heads}) must divide `block_out_channels` ({block_out_channels}). If you didn't set `num_attention_heads` the default settings don't match your model. Set `num_attention_heads` manually." + ) + + # - group norms + def group_norms_match_channel_sizes(num_groups, channel_sizes): + return all(c % num_groups == 0 for c in channel_sizes) + + if norm_num_groups is None: + if group_norms_match_channel_sizes(unet.config.norm_num_groups, block_out_channels): + norm_num_groups = unet.config.norm_num_groups + else: + norm_num_groups = min(block_out_channels) + + if group_norms_match_channel_sizes(norm_num_groups, block_out_channels): + print( + f"`norm_num_groups` was set to `min(block_out_channels)` (={norm_num_groups}) so it divides all block_out_channels` ({block_out_channels}). Set it explicitly to remove this information." + ) + else: + raise ValueError( + f"`block_out_channels` ({block_out_channels}) don't match the base models `norm_num_groups` ({unet.config.norm_num_groups}). Setting `norm_num_groups` to `min(block_out_channels)` ({norm_num_groups}) didn't fix this. Pass `norm_num_groups` explicitly so it divides all block_out_channels." + ) + + def get_time_emb_input_dim(unet: UNet2DConditionModel): + return unet.time_embedding.linear_1.in_features + + def get_time_emb_dim(unet: UNet2DConditionModel): + return unet.time_embedding.linear_2.out_features + + # Clone params from base unet if + # (i) it's required to build SD or SDXL, and + # (ii) it's not used for the time embedding (as time embedding of control model is never used), and + # (iii) it's not set further below anyway + to_keep = [ + "cross_attention_dim", + "down_block_types", + "sample_size", + "transformer_layers_per_block", + "up_block_types", + "upcast_attention", + ] + kwargs = {k: v for k, v in dict(unet.config).items() if k in to_keep} + kwargs.update(block_out_channels=block_out_channels) + kwargs.update(num_attention_heads=num_attention_heads) + kwargs.update(norm_num_groups=norm_num_groups) + + # Add controlnetxs-specific params + kwargs.update( + conditioning_channels=conditioning_channels, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + time_embedding_input_dim=get_time_emb_input_dim(unet), + time_embedding_dim=get_time_emb_dim(unet), + time_embedding_mix=time_embedding_mix, + learn_embedding=learn_embedding, + base_model_channel_sizes=ControlNetXSModel._gather_subblock_sizes(unet, base_or_control="base"), + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + ) + + return cls(**kwargs) + + def _make_zero_conv(self, in_channels, out_channels=None): + return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) + class ControlNetXSModel(ModelMixin, ConfigMixin): r""" @@ -490,193 +581,15 @@ def _gather_subblock_sizes(cls, unet: UNet2DConditionModel, base_or_control: str @register_to_config def __init__( self, - controlnet_conditioning_channel_order: str = "rgb", - time_embedding_input_dim: int = 320, - time_embedding_dim: int = 1280, + base_model: UNet2DConditionModel, + ctrl_model: ControlNetXSAddon, time_embedding_mix: float = 1.0, - learn_embedding: bool = False, - base_model_channel_sizes: Dict[str, List[Tuple[int]]] = { - "down": [ - (4, 320), - (320, 320), - (320, 320), - (320, 320), - (320, 640), - (640, 640), - (640, 640), - (640, 1280), - (1280, 1280), - ], - "mid": [(1280, 1280)], - "up": [ - (2560, 1280), - (2560, 1280), - (1920, 1280), - (1920, 640), - (1280, 640), - (960, 640), - (960, 320), - (640, 320), - (640, 320), - ], - }, - sample_size: Optional[int] = None, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - norm_num_groups: Optional[int] = 32, - cross_attention_dim: Union[int, Tuple[int]] = 1280, - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, - num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, - upcast_attention: bool = False, ): super().__init__() - # 1 - Create control unet - self.control_model = UNet2DConditionModel( - sample_size=sample_size, - down_block_types=down_block_types, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - norm_num_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - transformer_layers_per_block=transformer_layers_per_block, - attention_head_dim=num_attention_heads, - use_linear_projection=True, - upcast_attention=upcast_attention, - time_embedding_dim=time_embedding_dim, - ) - - @classmethod - def from_unet( - cls, - unet: UNet2DConditionModel, - conditioning_channels: int = 3, - conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), - controlnet_conditioning_channel_order: str = "rgb", - learn_embedding: bool = False, - time_embedding_mix: float = 1.0, - block_out_channels: Optional[Tuple[int]] = None, - size_ratio: Optional[float] = None, - num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, - norm_num_groups: Optional[int] = None, - ): - r""" - Instantiate a [`ControlNetXSModel`] from [`UNet2DConditionModel`]. - - Parameters: - unet (`UNet2DConditionModel`): - The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it. - conditioning_channels (`int`, defaults to 3): - Number of channels of conditioning input (e.g. an image) - conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): - The tuple of output channel for each block in the `controlnet_cond_embedding` layer. - controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): - The channel order of conditional image. Will convert to `rgb` if it's `bgr`. - learn_embedding (`bool`, defaults to `False`): - Wether to use time embedding of the control model. If yes, the time embedding is a linear interpolation - of the time embeddings of the control and base model with interpolation parameter - `time_embedding_mix**3`. - time_embedding_mix (`float`, defaults to 1.0): - Linear interpolation parameter used if `learn_embedding` is `True`. - block_out_channels (`Tuple[int]`, *optional*): - Down blocks output channels in control model. Either this or `size_ratio` must be given. - size_ratio (float, *optional*): - When given, block_out_channels is set to a relative fraction of the base model's block_out_channels. - Either this or `block_out_channels` must be given. - num_attention_heads (`Union[int, Tuple[int]]`, *optional*): - The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. - norm_num_groups (int, *optional*, defaults to `None`): - The number of groups to use for the normalization of the control unet. If `None`, - `int(unet.config.norm_num_groups * size_ratio)` is taken. - """ - - # Check input - fixed_size = block_out_channels is not None - relative_size = size_ratio is not None - if not (fixed_size ^ relative_size): - raise ValueError( - "Pass exactly one of `block_out_channels` (for absolute sizing) or `control_model_ratio` (for relative sizing)." - ) - - # Create model - if block_out_channels is None: - block_out_channels = [int(size_ratio * c) for c in unet.config.block_out_channels] - - # Check that attention heads and group norms match channel sizes - # - attention heads - def attn_heads_match_channel_sizes(attn_heads, channel_sizes): - if isinstance(attn_heads, (tuple, list)): - return all(c % a == 0 for a, c in zip(attn_heads, channel_sizes)) - else: - return all(c % attn_heads == 0 for c in channel_sizes) - - num_attention_heads = num_attention_heads or unet.config.attention_head_dim - if not attn_heads_match_channel_sizes(num_attention_heads, block_out_channels): - raise ValueError( - f"The dimension of attention heads ({num_attention_heads}) must divide `block_out_channels` ({block_out_channels}). If you didn't set `num_attention_heads` the default settings don't match your model. Set `num_attention_heads` manually." - ) - - # - group norms - def group_norms_match_channel_sizes(num_groups, channel_sizes): - return all(c % num_groups == 0 for c in channel_sizes) - - if norm_num_groups is None: - if group_norms_match_channel_sizes(unet.config.norm_num_groups, block_out_channels): - norm_num_groups = unet.config.norm_num_groups - else: - norm_num_groups = min(block_out_channels) - - if group_norms_match_channel_sizes(norm_num_groups, block_out_channels): - print( - f"`norm_num_groups` was set to `min(block_out_channels)` (={norm_num_groups}) so it divides all block_out_channels` ({block_out_channels}). Set it explicitly to remove this information." - ) - else: - raise ValueError( - f"`block_out_channels` ({block_out_channels}) don't match the base models `norm_num_groups` ({unet.config.norm_num_groups}). Setting `norm_num_groups` to `min(block_out_channels)` ({norm_num_groups}) didn't fix this. Pass `norm_num_groups` explicitly so it divides all block_out_channels." - ) - - def get_time_emb_input_dim(unet: UNet2DConditionModel): - return unet.time_embedding.linear_1.in_features - - def get_time_emb_dim(unet: UNet2DConditionModel): - return unet.time_embedding.linear_2.out_features - - # Clone params from base unet if - # (i) it's required to build SD or SDXL, and - # (ii) it's not used for the time embedding (as time embedding of control model is never used), and - # (iii) it's not set further below anyway - to_keep = [ - "cross_attention_dim", - "down_block_types", - "sample_size", - "transformer_layers_per_block", - "up_block_types", - "upcast_attention", - ] - kwargs = {k: v for k, v in dict(unet.config).items() if k in to_keep} - kwargs.update(block_out_channels=block_out_channels) - kwargs.update(num_attention_heads=num_attention_heads) - kwargs.update(norm_num_groups=norm_num_groups) - - # Add controlnetxs-specific params - kwargs.update( - conditioning_channels=conditioning_channels, - controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, - time_embedding_input_dim=get_time_emb_input_dim(unet), - time_embedding_dim=get_time_emb_dim(unet), - time_embedding_mix=time_embedding_mix, - learn_embedding=learn_embedding, - base_model_channel_sizes=ControlNetXSModel._gather_subblock_sizes(unet, base_or_control="base"), - conditioning_embedding_out_channels=conditioning_embedding_out_channels, - ) - - return cls(**kwargs) + self.base_model = base_model + self.ctrl_model = ctrl_model + self.time_embedding_mix = time_embedding_mix def forward( self, @@ -870,7 +783,7 @@ def forward( for m_base, m_ctrl in zip(base_mid_subblocks, ctrl_mid_subblocks): h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock - h_base = h_base + self.middle_block_out(h_ctrl) * next(scales) # D - add ctrl -> base + h_base = h_base + self.middle_zero_convs_out(h_ctrl) * next(scales) # D - add ctrl -> base # 3 - up for i, m_base in enumerate(base_up_subblocks): @@ -887,77 +800,6 @@ def forward( return ControlNetXSOutput(sample=h_base) - def _make_zero_conv(self, in_channels, out_channels=None): - # keep running track of channels sizes - self.in_channels = in_channels - self.out_channels = out_channels or in_channels - - return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) - - -def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by): - """Increase channels sizes to allow for additional concatted information from base model""" - m = unet.mid_block.resnets[0] - old_norm1, old_conv1 = m.norm1, m.conv1 - # norm - norm_args = "num_groups num_channels eps affine".split(" ") - for a in norm_args: - assert hasattr(old_norm1, a) - norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args} - norm_kwargs["num_channels"] += by # surgery done here - # conv1 - conv1_args = ( - "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ") - ) - for a in conv1_args: - assert hasattr(old_conv1, a) - conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args} - conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor. - conv1_kwargs["in_channels"] += by # surgery done here - # conv_shortcut - # as we changed the input size of the block, the input and output sizes are likely different, - # therefore we need a conv_shortcut (simply adding won't work) - conv_shortcut_args_kwargs = { - "in_channels": conv1_kwargs["in_channels"], - "out_channels": conv1_kwargs["out_channels"], - # default arguments from resnet.__init__ - "kernel_size": 1, - "stride": 1, - "padding": 0, - "bias": True, - } - # swap old with new modules - unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs) - unet.mid_block.resnets[0].conv1 = LoRACompatibleConv(**conv1_kwargs) - unet.mid_block.resnets[0].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs) - unet.mid_block.resnets[0].in_channels += by # surgery done here - - -def adjust_group_norms(unet: UNet2DConditionModel, max_num_group: int = 32): - def find_denominator(number, start): - if start >= number: - return number - while start != 0: - residual = number % start - if residual == 0: - return start - start -= 1 - - for block in [*unet.down_blocks, unet.mid_block]: - # resnets - for r in block.resnets: - if r.norm1.num_groups < max_num_group: - r.norm1.num_groups = find_denominator(r.norm1.num_channels, start=max_num_group) - - if r.norm2.num_groups < max_num_group: - r.norm2.num_groups = find_denominator(r.norm2.num_channels, start=max_num_group) - - # transformers - if hasattr(block, "attentions"): - for a in block.attentions: - if a.norm.num_groups < max_num_group: - a.norm.num_groups = find_denominator(a.norm.num_channels, start=max_num_group) - def zero_module(module): for p in module.parameters(): @@ -975,6 +817,7 @@ def __init__( num_attention_heads: Optional[int] = 1, cross_attention_dim: Optional[int] = 1024, upcast_attention: Optional[bool] = False, + norm_num_groups: int = 32, has_resnet = False, has_crossattn = False, has_downsampler = False, @@ -983,12 +826,15 @@ def __init__( self.has_cross_attention = True self.num_attention_heads = num_attention_heads + self.in_channels = in_channels + self.out_channels = out_channels if has_resnet: self.resnet = ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, + groups=norm_num_groups, eps=1e-5, ) else: @@ -1003,6 +849,7 @@ def __init__( cross_attention_dim=cross_attention_dim, use_linear_projection=True, upcast_attention=upcast_attention, + norm_num_groups=norm_num_groups, ) else: self.attention = None From b45de06a6d7a97470f7b928550050ea0cf8c8ed7 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 10 Jan 2024 18:29:36 +0100 Subject: [PATCH 04/75] checkin --- src/diffusers/models/controlnet_xs.py | 364 ++++++++++++++++++-------- 1 file changed, 259 insertions(+), 105 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 771e20cd833d..19de913177bd 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -27,6 +27,7 @@ ResnetBlock2D, Transformer2DModel, UNetMidBlock2DCrossAttn, + Upsample2D ) from .unet_2d_condition import UNet2DConditionModel @@ -135,11 +136,12 @@ def init_original(cls, sd_type): @register_to_config def __init__( self, + conditioning_channel_order: str = 'rgb', conditioning_channels: int = 3, conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), time_embedding_input_dim: int = 320, time_embedding_dim: int = 1280, - learn_embedding: bool = False, + learn_time_embedding: bool = False, base_model_channel_sizes: Dict[str, List[Tuple[int]]] = { "down - in": [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280], "down - out": [320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280], @@ -161,8 +163,6 @@ def __init__( ): super().__init__() - # todo: learn_embedding - self.sample_size = sample_size # `num_attention_heads` defaults to `attention_head_dim`. This looks weird upon first reading it and it is. @@ -173,16 +173,18 @@ def __init__( num_attention_heads = attention_head_dim # Check inputs - # todo + if conditioning_channel_order not in ["rgb", "bgr"]: + raise ValueError(f"unknown `conditioning_channel_order`: {conditioning_channel_order}") + # todo - other checks # input self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1) # time - time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0) - # note umer: here `time_embedding_input_dim` is used, so time info can be received from base model - self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embed_dim) + if learn_time_embedding: + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embed_dim) if addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0) @@ -223,8 +225,7 @@ def get_extra_channel(block_no, subblock_no): output_channel = block_out_channels[i] use_crossattention = down_block_type == "CrossAttnDownBlock2D" - self.down_subblocks.append(DownSubBlock2D( - has_resnet=True, + self.down_subblocks.append(CrossAttnSubBlock2D( has_crossattn=use_crossattention, in_channels=input_channel + get_extra_channel(block_no=i, subblock_no=0), out_channels=output_channel, @@ -235,8 +236,7 @@ def get_extra_channel(block_no, subblock_no): upcast_attention=upcast_attention, norm_num_groups=norm_num_groups, )) - self.down_subblocks.append(DownSubBlock2D( - has_resnet=True, + self.down_subblocks.append(CrossAttnSubBlock2D( has_crossattn=use_crossattention, in_channels=output_channel + get_extra_channel(block_no=i, subblock_no=1), out_channels=output_channel, @@ -248,8 +248,7 @@ def get_extra_channel(block_no, subblock_no): norm_num_groups=norm_num_groups, )) if i 0") @@ -710,32 +757,27 @@ def forward( class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype) temb = temb + class_emb - if base_model.config.addition_embed_type is not None: - if base_model.config.addition_embed_type == "text": - aug_emb = base_model.add_embedding(encoder_hidden_states) - elif base_model.config.addition_embed_type == "text_image": - raise NotImplementedError() - elif base_model.config.addition_embed_type == "text_time": - # SDXL - style - if "text_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" - ) - text_embeds = added_cond_kwargs.get("text_embeds") - if "time_ids" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" - ) - time_ids = added_cond_kwargs.get("time_ids") - time_embeds = base_model.add_time_proj(time_ids.flatten()) - time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) - add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) - add_embeds = add_embeds.to(temb.dtype) - aug_emb = base_model.add_embedding(add_embeds) - elif base_model.config.addition_embed_type == "image": - raise NotImplementedError() - elif base_model.config.addition_embed_type == "image_hint": - raise NotImplementedError() + if self.base_addition_embed_type is None: + pass + elif self.base_addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.base_add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(temb.dtype) + aug_emb = self.base_add_embedding(add_embeds) + else: + raise NotImplementedError() temb = temb + aug_emb if aug_emb is not None else temb @@ -747,21 +789,11 @@ def forward( h_ctrl = h_base = sample hs_base, hs_ctrl = [], [] - it_down_convs_in, it_down_convs_out, it_dec_convs_in, it_up_convs_out = map( - iter, (self.down_zero_convs_in, self.down_zero_convs_out, self.up_zero_convs_in, self.up_zero_convs_out) - ) - scales = iter(scale_list) - - base_down_subblocks = to_sub_blocks(base_model.down_blocks) - ctrl_down_subblocks = to_sub_blocks(self.control_model.down_blocks) - base_mid_subblocks = to_sub_blocks([base_model.mid_block]) - ctrl_mid_subblocks = to_sub_blocks([self.control_model.mid_block]) - base_up_subblocks = to_sub_blocks(base_model.up_blocks) # Cross Control # 0 - conv in - h_base = base_model.conv_in(h_base) - h_ctrl = self.control_model.conv_in(h_ctrl) + h_base = self.base_conv_in(h_base) + h_ctrl = self.ctrl_conv_in(h_ctrl) if guided_hint is not None: h_ctrl += guided_hint h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base @@ -770,30 +802,35 @@ def forward( hs_ctrl.append(h_ctrl) # 1 - down - for m_base, m_ctrl in zip(base_down_subblocks, ctrl_down_subblocks): + for m_base, m_ctrl in zip(self.base_down_subblocks, self.ctrl_down_subblocks): + if isinstance(m_base, CrossAttnSubBlock2D): + additional_params = [temb, cemb, attention_mask, cross_attention_kwargs] + else: + additional_params = [] + h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl - h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock - h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock + h_base = m_base(h_base, *additional_params) # B - apply base subblock + h_ctrl = m_ctrl(h_ctrl, *additional_params) # C - apply ctrl subblock h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base hs_base.append(h_base) hs_ctrl.append(h_ctrl) # 2 - mid h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl - for m_base, m_ctrl in zip(base_mid_subblocks, ctrl_mid_subblocks): + for m_base, m_ctrl in zip(self.base_mid_block, self.ctrl_mid_block): h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock h_base = h_base + self.middle_zero_convs_out(h_ctrl) * next(scales) # D - add ctrl -> base # 3 - up - for i, m_base in enumerate(base_up_subblocks): + for m_base in self.base_up_subblocks: h_base = h_base + next(it_up_convs_out)(hs_ctrl.pop()) * next(scales) # add info from ctrl encoder h_base = torch.cat([h_base, hs_base.pop()], dim=1) # concat info from base encoder+ctrl encoder h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) - h_base = base_model.conv_norm_out(h_base) - h_base = base_model.conv_act(h_base) - h_base = base_model.conv_out(h_base) + h_base = self.base_conv_norm_out(h_base) + h_base = self.base_conv_act(h_base) + h_base = self.base_conv_out(h_base) if not return_dict: return h_base @@ -807,38 +844,36 @@ def zero_module(module): return module -class DownSubBlock2D(nn.Module): +class CrossAttnSubBlock2D(nn.Module): def __init__( self, - in_channels: int, - out_channels: int, + is_empty: bool = False, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, temb_channels: Optional[int] = None, + norm_num_groups: Optional[int] = 32, + has_crossattn = False, transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, num_attention_heads: Optional[int] = 1, cross_attention_dim: Optional[int] = 1024, upcast_attention: Optional[bool] = False, - norm_num_groups: int = 32, - has_resnet = False, - has_crossattn = False, - has_downsampler = False, ): super().__init__() + self.gradient_checkpointing = False + + if is_empty: + return - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads self.in_channels = in_channels self.out_channels = out_channels - if has_resnet: - self.resnet = ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - groups=norm_num_groups, - eps=1e-5, - ) - else: - self.resnet = None + self.resnet = ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + groups=norm_num_groups, + eps=1e-5, + ) if has_crossattn: self.attention = Transformer2DModel( @@ -854,12 +889,15 @@ def __init__( else: self.attention = None - if has_downsampler: - self.downsampler = Downsample2D(out_channels, use_conv=True, out_channels=out_channels, name="op") - else: - self.downsampler = None - - self.gradient_checkpointing = False + @classmethod + def from_modules(cls, resnet: ResnetBlock2D, attention: Optional[Transformer2DModel] = None): + """Create empty subblock and set resnet and attention manually""" + subblock = cls(is_empty=True) + subblock.resnet = resnet + subblock.attention = attention + subblock.in_channels = resnet.in_channels + subblock.out_channels = resnet.out_channels + return subblock def forward( self, @@ -899,8 +937,6 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states) else: if self.resnet is not None: hidden_states = self.resnet(hidden_states, temb, scale=lora_scale) @@ -913,7 +949,125 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states) return hidden_states + + +class DownSubBlock2D(nn.Module): + def __init__( + self, + is_empty: bool = False, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + ): + super().__init__() + self.gradient_checkpointing = False + + if is_empty: + return + + self.in_channels = in_channels + self.out_channels = out_channels + + self.downsampler = Downsample2D(out_channels, use_conv=True, out_channels=out_channels, name="op") + + @classmethod + def from_modules(cls, downsampler: Downsample2D): + """Create empty subblock and set downsampler manually""" + subblock = cls(is_empty=True) + subblock.downsampler = downsampler + subblock.in_channels = downsampler.channels + subblock.out_channels = downsampler.out_channels + return subblock + + def forward( + self, + hidden_states: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + if self.training and self.gradient_checkpointing: + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + # todo: gradient ckptin? + hidden_states = self.downsampler(hidden_states) + else: + hidden_states = self.downsampler(hidden_states) + + return hidden_states + + +class CrossAttnUpSubBlock2D(nn.Module): + def __init__(self): + """todo doc - init emtpty as only from_modules will be used""" + super().__init__() + self.gradient_checkpointing = False + + @classmethod + def from_modules(cls, resnet: ResnetBlock2D, attention: Optional[Transformer2DModel] = None, upsampler: Optional[Upsample2D] = None): + """Create empty subblock and set resnet, attention and upsampler manually""" + subblock = cls(is_empty=True) + subblock.resnet = resnet + subblock.attention = attention + subblock.upsampler = upsampler + subblock.in_channels = resnet.in_channels + subblock.out_channels = resnet.out_channels + return subblock + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + if self.training and self.gradient_checkpointing: + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + if self.attention is not None: + hidden_states = self.attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = self.upsampler(hidden_states) + else: + hidden_states = self.resnet(hidden_states, temb, scale=lora_scale) + if self.attention is not None: + hidden_states = self.attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = self.upsampler(hidden_states) + + return hidden_states \ No newline at end of file From eb9c59efde8e56570191e10e3c826b02cb1ce25f Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 10 Jan 2024 23:29:02 +0100 Subject: [PATCH 05/75] Removed iter, next in forward --- src/diffusers/models/controlnet_xs.py | 35 +++++++++++++++------------ 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 19de913177bd..ae433d2ec0d8 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -594,7 +594,13 @@ def __init__( self.ctrl_down_subblocks = ctrl_model.down_subblocks self.ctrl_mid_block = ctrl_model.mid_block - # 2 - Save base model parts + # 2 - Save connections + self.down_zero_convs_in = ctrl_model.down_zero_convs_in + self.down_zero_convs_out = ctrl_model.down_zero_convs_out + self.middle_zero_convs_out = ctrl_model.middle_zero_convs_out + self.up_zero_convs_out = ctrl_model.up_zero_convs_out + + # 3 - Save base model parts self.base_time_proj = base_model.time_proj self.base_time_embedding = base_model.time_embedding self.base_class_embedding = base_model.class_embedding @@ -606,7 +612,7 @@ def __init__( self.base_mid_block = base_model.mid_block self.base_up_subblocks = nn.ModuleList() - # 2.1 - Decompose blocks of base model into subblocks + # 3.1 - Decompose blocks of base model into subblocks for block in base_model.down_blocks: # Each ResNet / Attention pair is a subblock resnets = block.resnets @@ -704,7 +710,6 @@ def forward( # scale control strength n_connections = len(self.down_zero_convs_out) + 1 + len(self.up_zero_convs_out) - scale_list = torch.full((n_connections,), conditioning_scale) # prepare attention_mask if attention_mask is not None: @@ -796,37 +801,37 @@ def forward( h_ctrl = self.ctrl_conv_in(h_ctrl) if guided_hint is not None: h_ctrl += guided_hint - h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base + h_base = h_base + self.down_zero_convs_out[0](h_ctrl) * conditioning_scale # D - add ctrl -> base hs_base.append(h_base) hs_ctrl.append(h_ctrl) # 1 - down - for m_base, m_ctrl in zip(self.base_down_subblocks, self.ctrl_down_subblocks): - if isinstance(m_base, CrossAttnSubBlock2D): + for b, c, b2c, c2b in zip(self.base_down_subblocks, self.ctrl_down_subblocks, self.down_zero_convs_in[:-1], self.down_zero_convs_out[1:]): + if isinstance(b, CrossAttnSubBlock2D): additional_params = [temb, cemb, attention_mask, cross_attention_kwargs] else: additional_params = [] - h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl - h_base = m_base(h_base, *additional_params) # B - apply base subblock - h_ctrl = m_ctrl(h_ctrl, *additional_params) # C - apply ctrl subblock - h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base + h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) # A - concat base -> ctrl + h_base = b(h_base, *additional_params) # B - apply base subblock + h_ctrl = c(h_ctrl, *additional_params) # C - apply ctrl subblock + h_base = h_base + c2b(h_ctrl) * conditioning_scale # D - add ctrl -> base hs_base.append(h_base) hs_ctrl.append(h_ctrl) # 2 - mid - h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl + h_ctrl = torch.cat([h_ctrl, self.down_zero_convs_in[-1](h_base)], dim=1) # A - concat base -> ctrl for m_base, m_ctrl in zip(self.base_mid_block, self.ctrl_mid_block): h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock - h_base = h_base + self.middle_zero_convs_out(h_ctrl) * next(scales) # D - add ctrl -> base + h_base = h_base + self.middle_zero_convs_out(h_ctrl) * conditioning_scale # D - add ctrl -> base # 3 - up - for m_base in self.base_up_subblocks: - h_base = h_base + next(it_up_convs_out)(hs_ctrl.pop()) * next(scales) # add info from ctrl encoder + for b, c2b in zip(self.base_up_subblocks, self.up_zero_convs_out): + h_base = h_base + c2b(hs_ctrl.pop()) * conditioning_scale # add info from ctrl encoder h_base = torch.cat([h_base, hs_base.pop()], dim=1) # concat info from base encoder+ctrl encoder - h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) + h_base = b(h_base, temb, cemb, attention_mask, cross_attention_kwargs) h_base = self.base_conv_norm_out(h_base) h_base = self.base_conv_act(h_base) From 32bf5a7507df57e4bf294952dadfe7f1f581ded5 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 11 Jan 2024 10:23:06 +0100 Subject: [PATCH 06/75] Models for SD21 & SDXL run through --- src/diffusers/models/controlnet_xs.py | 100 +++++++++++++------------ src/diffusers/models/unet_2d_blocks.py | 17 +++-- 2 files changed, 61 insertions(+), 56 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index ae433d2ec0d8..899e9b5b1c60 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -6,15 +6,12 @@ import torch.utils.checkpoint from torch import nn from torch.nn import functional as F -from torch.nn.modules.normalization import GroupNorm from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging, is_torch_version from .attention_processor import ( AttentionProcessor, ) -from .autoencoders import AutoencoderKL -from .lora import LoRACompatibleConv from .embeddings import ( TimestepEmbedding, Timesteps, @@ -145,7 +142,7 @@ def __init__( base_model_channel_sizes: Dict[str, List[Tuple[int]]] = { "down - in": [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280], "down - out": [320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280], - "mid": 1280, + "mid - out": 1280, "up - in": [1280, 1280, 1280, 1280,1280, 1280, 1280, 640, 640, 640, 320, 320], }, addition_embed_type = None, @@ -182,9 +179,12 @@ def __init__( # time if learn_time_embedding: - time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + time_embedding_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0) - self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embed_dim) + self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embedding_dim) + else: + self.time_proj = None + self.time_embedding = None if addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0) @@ -203,8 +203,6 @@ def __init__( if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) - blocks_time_embed_dim = time_embed_dim - # down def get_extra_channel(block_no, subblock_no): """Determine channel size for extra info from base - todo""" @@ -229,7 +227,7 @@ def get_extra_channel(block_no, subblock_no): has_crossattn=use_crossattention, in_channels=input_channel + get_extra_channel(block_no=i, subblock_no=0), out_channels=output_channel, - temb_channels=blocks_time_embed_dim, + temb_channels=time_embedding_dim, transformer_layers_per_block=transformer_layers_per_block[i], num_attention_heads=num_attention_heads[i], cross_attention_dim=cross_attention_dim, @@ -240,7 +238,7 @@ def get_extra_channel(block_no, subblock_no): has_crossattn=use_crossattention, in_channels=output_channel + get_extra_channel(block_no=i, subblock_no=1), out_channels=output_channel, - temb_channels=blocks_time_embed_dim, + temb_channels=time_embedding_dim, transformer_layers_per_block=transformer_layers_per_block[i], num_attention_heads=num_attention_heads[i], cross_attention_dim=cross_attention_dim, @@ -248,28 +246,24 @@ def get_extra_channel(block_no, subblock_no): norm_num_groups=norm_num_groups, )) if i ctrl - for m_base, m_ctrl in zip(self.base_mid_block, self.ctrl_mid_block): - h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock - h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock + h_base = self.base_mid_block(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock + h_ctrl = self.ctrl_mid_block(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock h_base = h_base + self.middle_zero_convs_out(h_ctrl) * conditioning_scale # D - add ctrl -> base # 3 - up - for b, c2b in zip(self.base_up_subblocks, self.up_zero_convs_out): - h_base = h_base + c2b(hs_ctrl.pop()) * conditioning_scale # add info from ctrl encoder - h_base = torch.cat([h_base, hs_base.pop()], dim=1) # concat info from base encoder+ctrl encoder + for b, c2b, skip_c, skip_b in zip(self.base_up_subblocks, self.up_zero_convs_out, reversed(hs_ctrl), reversed(hs_base)): + h_base = h_base + c2b(skip_c) * conditioning_scale # add info from ctrl encoder + h_base = torch.cat([h_base, skip_b], dim=1) # concat info from base encoder+ctrl encoder h_base = b(h_base, temb, cemb, attention_mask, cross_attention_kwargs) h_base = self.base_conv_norm_out(h_base) @@ -934,7 +934,7 @@ def custom_forward(*inputs): **ckpt_kwargs, ) if self.attention is not None: - hidden_states = self.attn( + hidden_states = self.attention( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, @@ -946,7 +946,7 @@ def custom_forward(*inputs): if self.resnet is not None: hidden_states = self.resnet(hidden_states, temb, scale=lora_scale) if self.attention is not None: - hidden_states = self.attn( + hidden_states = self.attention( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, @@ -974,7 +974,7 @@ def __init__( self.in_channels = in_channels self.out_channels = out_channels - self.downsampler = Downsample2D(out_channels, use_conv=True, out_channels=out_channels, name="op") + self.downsampler = Downsample2D(in_channels, use_conv=True, out_channels=out_channels, name="op") @classmethod def from_modules(cls, downsampler: Downsample2D): @@ -1016,7 +1016,7 @@ def __init__(self): @classmethod def from_modules(cls, resnet: ResnetBlock2D, attention: Optional[Transformer2DModel] = None, upsampler: Optional[Upsample2D] = None): """Create empty subblock and set resnet, attention and upsampler manually""" - subblock = cls(is_empty=True) + subblock = cls() subblock.resnet = resnet subblock.attention = attention subblock.upsampler = upsampler @@ -1053,7 +1053,7 @@ def custom_forward(*inputs): **ckpt_kwargs, ) if self.attention is not None: - hidden_states = self.attn( + hidden_states = self.attention( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, @@ -1061,11 +1061,12 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = self.upsampler(hidden_states) + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states) else: hidden_states = self.resnet(hidden_states, temb, scale=lora_scale) if self.attention is not None: - hidden_states = self.attn( + hidden_states = self.attention( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, @@ -1073,6 +1074,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = self.upsampler(hidden_states) + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states) return hidden_states \ No newline at end of file diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index e404cef224ff..3614fc06f301 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -632,6 +632,7 @@ def __init__( self, in_channels: int, temb_channels: int, + out_channels: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: Union[int, Tuple[int]] = 1, @@ -650,6 +651,8 @@ def __init__( ): super().__init__() + out_channels = out_channels or in_channels + self.has_cross_attention = True self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) @@ -662,7 +665,7 @@ def __init__( resnets = [ ResnetBlock2D( in_channels=in_channels, - out_channels=in_channels, + out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, @@ -680,8 +683,8 @@ def __init__( attentions.append( Transformer2DModel( num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, + out_channels // num_attention_heads, + in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, @@ -694,8 +697,8 @@ def __init__( attentions.append( DualTransformer2DModel( num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, + out_channels // num_attention_heads, + in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, @@ -703,8 +706,8 @@ def __init__( ) resnets.append( ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, + in_channels=out_channels, + out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, From b58572bc053e8af60e44b2542cb65d21bc8113eb Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Fri, 12 Jan 2024 10:58:47 +0100 Subject: [PATCH 07/75] Added back pipelines, cleared up connections --- src/diffusers/__init__.py | 8 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/controlnet_xs.py | 537 +++----- src/diffusers/pipelines/__init__.py | 10 + .../pipelines/controlnet_xs/__init__.py | 68 + .../controlnet_xs/pipeline_controlnet_xs.py | 946 ++++++++++++++ .../pipeline_controlnet_xs_sd_xl.py | 1120 +++++++++++++++++ 7 files changed, 2347 insertions(+), 344 deletions(-) create mode 100644 src/diffusers/pipelines/controlnet_xs/__init__.py create mode 100644 src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py create mode 100644 src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 180b210953c1..83eff642be93 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -80,6 +80,8 @@ "AutoencoderTiny", "ConsistencyDecoderVAE", "ControlNetModel", + "ControlNetXSAddon", + "ControlNetXSModel", "Kandinsky3UNet", "ModelMixin", "MotionAdapter", @@ -255,6 +257,7 @@ "StableDiffusionControlNetImg2ImgPipeline", "StableDiffusionControlNetInpaintPipeline", "StableDiffusionControlNetPipeline", + "StableDiffusionControlNetXSPipeline", "StableDiffusionDepth2ImgPipeline", "StableDiffusionDiffEditPipeline", "StableDiffusionGLIGENPipeline", @@ -278,6 +281,7 @@ "StableDiffusionXLControlNetImg2ImgPipeline", "StableDiffusionXLControlNetInpaintPipeline", "StableDiffusionXLControlNetPipeline", + "StableDiffusionXLControlNetXSPipeline", "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", @@ -459,6 +463,8 @@ AutoencoderTiny, ConsistencyDecoderVAE, ControlNetModel, + ControlNetXSAddon, + ControlNetXSModel, Kandinsky3UNet, ModelMixin, MotionAdapter, @@ -613,6 +619,7 @@ StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline, + StableDiffusionControlNetXSPipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionDiffEditPipeline, StableDiffusionGLIGENPipeline, @@ -636,6 +643,7 @@ StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetPipeline, + StableDiffusionXLControlNetXSPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 36dbe14c5053..71e309d97bbe 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -32,6 +32,7 @@ _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["controlnet"] = ["ControlNetModel"] + _import_structure["controlnet_xs"] = ["ControlNetXSAddon", "ControlNetXSModel"] _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] @@ -66,6 +67,7 @@ ConsistencyDecoderVAE, ) from .controlnet import ControlNetModel + from .controlnet_xs import ControlNetXSAddon, ControlNetXSModel from .dual_transformer_2d import DualTransformer2DModel from .embeddings import ImageProjection from .modeling_utils import ModelMixin diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 899e9b5b1c60..b1e6c5c2105a 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -8,24 +8,14 @@ from torch.nn import functional as F from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, logging, is_torch_version -from .attention_processor import ( - AttentionProcessor, -) +from ..utils import BaseOutput, is_torch_version, logging +from .autoencoders import AutoencoderKL from .embeddings import ( TimestepEmbedding, Timesteps, ) from .modeling_utils import ModelMixin -from .unet_2d_blocks import ( - CrossAttnDownBlock2D, - DownBlock2D, - Downsample2D, - ResnetBlock2D, - Transformer2DModel, - UNetMidBlock2DCrossAttn, - Upsample2D -) +from .unet_2d_blocks import Downsample2D, ResnetBlock2D, Transformer2DModel, UNetMidBlock2DCrossAttn, Upsample2D from .unet_2d_condition import UNet2DConditionModel @@ -94,69 +84,81 @@ def forward(self, conditioning): class ControlNetXSAddon(ModelMixin, ConfigMixin): @classmethod - def init_original(cls, sd_type): - # todo - kwargs = {} - if sd_type == "sdxl": - kwargs.update({ - 'addition_embed_type': "text_time", - 'addition_time_embed_dim': 256, - 'attention_head_dim': [5, 10, 20], - 'block_out_channels': [320, 640, 1280], - 'cross_attention_dim': 2048, - 'down_block_types': ['DownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D'], - 'projection_class_embeddings_input_dim': 2816, - 'sample_size': 128, - 'transformer_layers_per_block': [1, 2, 10], - 'up_block_types': ['CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'UpBlock2D'], - 'upcast_attention': None, - }) - elif sd_type == "sd": - kwargs.update({ - 'addition_embed_type': None, - 'addition_time_embed_dim': None, - 'attention_head_dim': [5, 10, 20, 20], - 'block_out_channels': [320, 640, 1280, 1280], - 'cross_attention_dim': 1024, - 'down_block_types': ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'], - 'projection_class_embeddings_input_dim': None, - 'sample_size': 96, - 'transformer_layers_per_block': 1, - 'up_block_types': ['UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'], - 'upcast_attention': True - }) - else: - raise ValueError("`sd_type` needs to either 'sd' or 'sdxl'") + def from_unet( + cls, + base_model: UNet2DConditionModel, + size_ratio: Optional[float] = None, + block_out_channels: Optional[List[int]] = None, + num_attention_heads: Optional[List[int]] = None, + learn_time_embedding: bool = False, + ): + # todo - comment - return ControlNetXSAddon(**kwargs) + # Check input + fixed_size = block_out_channels is not None + relative_size = size_ratio is not None + if not (fixed_size ^ relative_size): + raise ValueError( + "Pass exactly one of `block_out_channels` (for absolute sizing) or `control_model_ratio` (for relative sizing)." + ) + + channels_base = { # todo + "down - in": [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280], + "down - out": [320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280], + "mid - out": 1280, + "up - in": [1280, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320], + } + + block_out_channels = [int(b * size_ratio) for b in base_model.config.block_out_channels] + if num_attention_heads is None: + num_attention_heads = base_model.config.num_attention_heads + + norm_num_groups = math.gcd(*block_out_channels) + + return ControlNetXSAddon( + learn_time_embedding=learn_time_embedding, + channels_base=channels_base, + addition_embed_type=base_model.config.addition_embed_type, + addition_time_embed_dim=base_model.config.addition_time_embed_dim, + attention_head_dim=num_attention_heads, + block_out_channels=block_out_channels, + base_block_out_channels=base_model.config.block_out_channels, + cross_attention_dim=base_model.config.cross_attention_dim, + down_block_types=base_model.config.down_block_types, + projection_class_embeddings_input_dim=base_model.config.projection_class_embeddings_input_dim, + sample_size=base_model.config.sample_size, + transformer_layers_per_block=base_model.config.transformer_layers_per_block, + upcast_attention=base_model.config.upcast_attention, + norm_num_groups=norm_num_groups, + ) @register_to_config def __init__( self, - conditioning_channel_order: str = 'rgb', + conditioning_channel_order: str = "rgb", conditioning_channels: int = 3, conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), time_embedding_input_dim: int = 320, time_embedding_dim: int = 1280, learn_time_embedding: bool = False, - base_model_channel_sizes: Dict[str, List[Tuple[int]]] = { - "down - in": [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280], - "down - out": [320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280], + channels_base: Dict[str, List[Tuple[int]]] = { + "down - in": [320, 320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280], + "down - out": [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280], "mid - out": 1280, - "up - in": [1280, 1280, 1280, 1280,1280, 1280, 1280, 640, 640, 640, 320, 320], + "up - in": [1280, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320], }, - addition_embed_type = None, - addition_time_embed_dim = None, - attention_head_dim = [4], - block_out_channels = [4, 8, 16, 16], - base_block_out_channels = [320, 640, 1280, 1280], - cross_attention_dim = 1024, - down_block_types = ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D','CrossAttnDownBlock2D', 'DownBlock2D'], - projection_class_embeddings_input_dim = None, - sample_size = 96, + addition_embed_type=None, + addition_time_embed_dim=None, + attention_head_dim=[4], + block_out_channels=[4, 8, 16, 16], + base_block_out_channels=[320, 640, 1280, 1280], + cross_attention_dim=1024, + down_block_types=["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], + projection_class_embeddings_input_dim=None, + sample_size=96, transformer_layers_per_block: Union[int, Tuple[int]] = 1, - upcast_attention = True, - norm_num_groups = 4, + upcast_attention=True, + norm_num_groups=4, ): super().__init__() @@ -188,12 +190,12 @@ def __init__( if addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0) - self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embedding_dim) elif addition_embed_type is not None: raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") self.time_embed_act = None - + self.down_subblocks = nn.ModuleList([]) self.up_subblocks = nn.ModuleList([]) @@ -204,17 +206,17 @@ def __init__( transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) # down - def get_extra_channel(block_no, subblock_no): - """Determine channel size for extra info from base - todo""" - if block_no==0: - # in 1st block: all same - todo + def channels_from_base(block_no, subblock_no): + """Determine channel size for extra info from base model""" + if block_no == 0: + # in 1st block: all subblocks have same channels return base_block_out_channels[0] else: - if subblock_no==0: - # in 2nd+ block: in 1st subblock, no change yet - todo - return base_block_out_channels[block_no-1] + if subblock_no == 0: + # in 2nd+ block: the 1st subblock has same channels as in the last block + return base_block_out_channels[block_no - 1] else: - # in 2nd+ block: in 2nd+ subblock, resnet has double channels -> change - todo + # in 2nd+ block: after the 1st subblock, the channels have changed return base_block_out_channels[block_no] output_channel = block_out_channels[0] @@ -223,39 +225,45 @@ def get_extra_channel(block_no, subblock_no): output_channel = block_out_channels[i] use_crossattention = down_block_type == "CrossAttnDownBlock2D" - self.down_subblocks.append(CrossAttnSubBlock2D( - has_crossattn=use_crossattention, - in_channels=input_channel + get_extra_channel(block_no=i, subblock_no=0), - out_channels=output_channel, - temb_channels=time_embedding_dim, - transformer_layers_per_block=transformer_layers_per_block[i], - num_attention_heads=num_attention_heads[i], - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - norm_num_groups=norm_num_groups, - )) - self.down_subblocks.append(CrossAttnSubBlock2D( - has_crossattn=use_crossattention, - in_channels=output_channel + get_extra_channel(block_no=i, subblock_no=1), - out_channels=output_channel, - temb_channels=time_embedding_dim, - transformer_layers_per_block=transformer_layers_per_block[i], - num_attention_heads=num_attention_heads[i], - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - norm_num_groups=norm_num_groups, - )) - if i ctrl ; c2b = ctrl -> base + self.down_zero_convs_b2c = nn.ModuleList([]) + self.down_zero_convs_c2b = nn.ModuleList([]) + self.mid_zero_convs_c2b = nn.ModuleList([]) + self.up_zero_convs_c2b = nn.ModuleList([]) + # 4.1 - Connections from base encoder to ctrl encoder # Information is passed from base to ctrl _before_ each subblock. We therefore use the 'in' channels. # As the information is concatted in ctrl, we don't need to change channel sizes. So channels in = channels out. - for c in base_model_channel_sizes['down - in']: - self.down_zero_convs_in.append(self._make_zero_conv(c, c)) - c = base_model_channel_sizes['mid - out'] - self.down_zero_convs_in.append(self._make_zero_conv(c, c)) + for c in channels_base["down - in"]: + self.down_zero_convs_b2c.append(self._make_zero_conv(c, c)) # 4.2 - Connections from ctrl encoder to base encoder # Information is passed from ctrl to base _after_ each subblock. We therefore use the 'out' channels. # As the information is added to base, the out-channels need to match base. - for i in range(len(self.down_subblocks)): - ch_base_out = base_model_channel_sizes['down - out'][i] - ch_ctrl_out = self.ch_inout_ctrl['down - out'][i] - if i==0: - # for conv_in - self.down_zero_convs_out.append(self._make_zero_conv(self.conv_in.out_channels, ch_base_out)) - self.down_zero_convs_out.append(self._make_zero_conv(ch_ctrl_out, ch_base_out)) + for ch_base, ch_ctrl in zip(channels_base["down - out"], channels_ctrl["down - out"]): + self.down_zero_convs_c2b.append(self._make_zero_conv(ch_ctrl, ch_base)) # 4.3 - Connections in mid block - # todo - better naming? - ch_base_out = base_model_channel_sizes['mid - out'] - ch_ctrl_out = self.ch_inout_ctrl['mid - out'] - self.middle_zero_convs_out = self._make_zero_conv(ch_ctrl_out, ch_base_out) + self.mid_zero_convs_c2b = self._make_zero_conv(channels_ctrl["mid - out"], channels_base["mid - out"]) # 4.3 - Connections from ctrl encoder to base decoder - # todo - skip_channels = reversed([self.conv_in.out_channels] + self.ch_inout_ctrl['down - out']) - for s,i in zip(skip_channels, base_model_channel_sizes['up - in']): - self.up_zero_convs_out.append(self._make_zero_conv(s, i)) + skip_channels = reversed(channels_ctrl["down - out"]) + for s, i in zip(skip_channels, channels_base["up - in"]): + self.up_zero_convs_c2b.append(self._make_zero_conv(s, i)) # 5 - Create conditioning hint embedding self.controlnet_cond_embedding = ControlNetConditioningEmbedding( @@ -319,134 +316,16 @@ def get_extra_channel(block_no, subblock_no): ) def forward(self, *args, **kwargs): - raise ValueError("A ControlNetXSAddonModel cannot be run by itself. Pass it into a ControlNetXSModel model instead.") - - @classmethod - def from_unet( - cls, - unet: UNet2DConditionModel, - conditioning_channels: int = 3, - conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), - controlnet_conditioning_channel_order: str = "rgb", - learn_embedding: bool = False, - time_embedding_mix: float = 1.0, - block_out_channels: Optional[Tuple[int]] = None, - size_ratio: Optional[float] = None, - num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, - norm_num_groups: Optional[int] = None, - ): - # todo - r""" - Instantiate a [`ControlNetXSModel`] from [`UNet2DConditionModel`]. - - Parameters: - unet (`UNet2DConditionModel`): - The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it. - conditioning_channels (`int`, defaults to 3): - Number of channels of conditioning input (e.g. an image) - conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): - The tuple of output channel for each block in the `controlnet_cond_embedding` layer. - controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): - The channel order of conditional image. Will convert to `rgb` if it's `bgr`. - learn_embedding (`bool`, defaults to `False`): - Wether to use time embedding of the control model. If yes, the time embedding is a linear interpolation - of the time embeddings of the control and base model with interpolation parameter - `time_embedding_mix**3`. - time_embedding_mix (`float`, defaults to 1.0): - Linear interpolation parameter used if `learn_embedding` is `True`. - block_out_channels (`Tuple[int]`, *optional*): - Down blocks output channels in control model. Either this or `size_ratio` must be given. - size_ratio (float, *optional*): - When given, block_out_channels is set to a relative fraction of the base model's block_out_channels. - Either this or `block_out_channels` must be given. - num_attention_heads (`Union[int, Tuple[int]]`, *optional*): - The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. - norm_num_groups (int, *optional*, defaults to `None`): - The number of groups to use for the normalization of the control unet. If `None`, - `int(unet.config.norm_num_groups * size_ratio)` is taken. - """ - - # Check input - fixed_size = block_out_channels is not None - relative_size = size_ratio is not None - if not (fixed_size ^ relative_size): - raise ValueError( - "Pass exactly one of `block_out_channels` (for absolute sizing) or `control_model_ratio` (for relative sizing)." - ) - - # Create model - if block_out_channels is None: - block_out_channels = [int(size_ratio * c) for c in unet.config.block_out_channels] - - # Check that attention heads and group norms match channel sizes - # - attention heads - def attn_heads_match_channel_sizes(attn_heads, channel_sizes): - if isinstance(attn_heads, (tuple, list)): - return all(c % a == 0 for a, c in zip(attn_heads, channel_sizes)) - else: - return all(c % attn_heads == 0 for c in channel_sizes) - - num_attention_heads = num_attention_heads or unet.config.attention_head_dim - if not attn_heads_match_channel_sizes(num_attention_heads, block_out_channels): - raise ValueError( - f"The dimension of attention heads ({num_attention_heads}) must divide `block_out_channels` ({block_out_channels}). If you didn't set `num_attention_heads` the default settings don't match your model. Set `num_attention_heads` manually." - ) - - # - group norms - def group_norms_match_channel_sizes(num_groups, channel_sizes): - return all(c % num_groups == 0 for c in channel_sizes) - - if norm_num_groups is None: - if group_norms_match_channel_sizes(unet.config.norm_num_groups, block_out_channels): - norm_num_groups = unet.config.norm_num_groups - else: - norm_num_groups = min(block_out_channels) - - if group_norms_match_channel_sizes(norm_num_groups, block_out_channels): - print( - f"`norm_num_groups` was set to `min(block_out_channels)` (={norm_num_groups}) so it divides all block_out_channels` ({block_out_channels}). Set it explicitly to remove this information." - ) - else: - raise ValueError( - f"`block_out_channels` ({block_out_channels}) don't match the base models `norm_num_groups` ({unet.config.norm_num_groups}). Setting `norm_num_groups` to `min(block_out_channels)` ({norm_num_groups}) didn't fix this. Pass `norm_num_groups` explicitly so it divides all block_out_channels." - ) - - def get_time_emb_input_dim(unet: UNet2DConditionModel): - return unet.time_embedding.linear_1.in_features - - def get_time_emb_dim(unet: UNet2DConditionModel): - return unet.time_embedding.linear_2.out_features - - # Clone params from base unet if - # (i) it's required to build SD or SDXL, and - # (ii) it's not used for the time embedding (as time embedding of control model is never used), and - # (iii) it's not set further below anyway - to_keep = [ - "cross_attention_dim", - "down_block_types", - "sample_size", - "transformer_layers_per_block", - "up_block_types", - "upcast_attention", - ] - kwargs = {k: v for k, v in dict(unet.config).items() if k in to_keep} - kwargs.update(block_out_channels=block_out_channels) - kwargs.update(num_attention_heads=num_attention_heads) - kwargs.update(norm_num_groups=norm_num_groups) - - # Add controlnetxs-specific params - kwargs.update( - conditioning_channels=conditioning_channels, - controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, - time_embedding_input_dim=get_time_emb_input_dim(unet), - time_embedding_dim=get_time_emb_dim(unet), - time_embedding_mix=time_embedding_mix, - learn_embedding=learn_embedding, - base_model_channel_sizes=ControlNetXSModel._gather_subblock_sizes(unet, base_or_control="base"), - conditioning_embedding_out_channels=conditioning_embedding_out_channels, + raise ValueError( + "A ControlNetXSAddonModel cannot be run by itself. Pass it into a ControlNetXSModel model instead." ) - return cls(**kwargs) + @torch.no_grad() + def _check_if_vae_compatible(self, vae: AutoencoderKL): + condition_downscale_factor = 2 ** (len(self.config.conditioning_embedding_out_channels) - 1) + vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + compatible = condition_downscale_factor == vae_downscale_factor + return compatible, condition_downscale_factor, vae_downscale_factor def _make_zero_conv(self, in_channels, out_channels=None): return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) @@ -479,7 +358,7 @@ class ControlNetXSModel(ModelMixin, ConfigMixin): time_embedding_mix (`float`, defaults to 1.0): Linear interpolation parameter used if `learn_embedding` is `True`. A value of 1.0 means only the control model's time embedding will be used. A value of 0.0 means only the base model's time embedding will be used. - base_model_channel_sizes (`Dict[str, List[Tuple[int]]]`): + channels_base (`Dict[str, List[Tuple[int]]]`): Channel sizes of each subblock of base model. Use `gather_subblock_sizes` on your base model to compute it. """ @@ -506,70 +385,23 @@ def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_ return dim_attn_heads if is_sdxl: - return ControlNetXSModel.from_unet( + time_embedding_mix = 0.95 + controlnet_addon = ControlNetXSAddon.from_unet( base_model, - time_embedding_mix=0.95, - learn_embedding=True, + learn_time_embedding=True, size_ratio=0.1, - conditioning_embedding_out_channels=(16, 32, 96, 256), num_attention_heads=get_dim_attn_heads(base_model, 0.1, 64), ) else: - return ControlNetXSModel.from_unet( + time_embedding_mix = 1.0 + controlnet_addon = ControlNetXSAddon.from_unet( base_model, - time_embedding_mix=1.0, - learn_embedding=True, + learn_time_embedding=True, size_ratio=0.0125, - conditioning_embedding_out_channels=(16, 32, 96, 256), num_attention_heads=get_dim_attn_heads(base_model, 0.0125, 8), ) - @classmethod - def _gather_subblock_sizes(cls, unet: UNet2DConditionModel, base_or_control: str): - """To create correctly sized connections between base and control model, we need to know - the input and output channels of each subblock. - - Parameters: - unet (`UNet2DConditionModel`): - Unet of which the subblock channels sizes are to be gathered. - base_or_control (`str`): - Needs to be either "base" or "control". If "base", decoder is also considered. - """ - if base_or_control not in ["base", "control"]: - raise ValueError("`base_or_control` needs to be either `base` or `control`") - - channel_sizes = {"down": [], "mid": [], "up": []} - - # input convolution - channel_sizes["down"].append((unet.conv_in.in_channels, unet.conv_in.out_channels)) - - # encoder blocks - for module in unet.down_blocks: - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): - for r in module.resnets: - channel_sizes["down"].append((r.in_channels, r.out_channels)) - if module.downsamplers: - channel_sizes["down"].append( - (module.downsamplers[0].channels, module.downsamplers[0].out_channels) - ) - else: - raise ValueError(f"Encountered unknown module of type {type(module)} while creating ControlNet-XS.") - - # middle block - channel_sizes["mid"].append((unet.mid_block.resnets[0].in_channels, unet.mid_block.resnets[0].out_channels)) - - # decoder blocks - if base_or_control == "base": - for module in unet.up_blocks: - if isinstance(module, (CrossAttnUpBlock2D, UpBlock2D)): - for r in module.resnets: - channel_sizes["up"].append((r.in_channels, r.out_channels)) - else: - raise ValueError( - f"Encountered unknown module of type {type(module)} while creating ControlNet-XS." - ) - - return channel_sizes + return cls(base_model=base_model, ctrl_model=controlnet_addon, time_embedding_mix=time_embedding_mix) @register_to_config def __init__( @@ -583,6 +415,7 @@ def __init__( # 1 - Save options self.use_ctrl_time_embedding = ctrl_model.config.learn_time_embedding self.conditioning_channel_order = ctrl_model.config.conditioning_channel_order + self.class_embed_type = base_model.config.class_embed_type # 2 - Save control model parts self.ctrl_time_embedding = ctrl_model.time_embedding @@ -592,10 +425,10 @@ def __init__( self.ctrl_mid_block = ctrl_model.mid_block # 3 - Save connections - self.down_zero_convs_in = ctrl_model.down_zero_convs_in - self.down_zero_convs_out = ctrl_model.down_zero_convs_out - self.middle_zero_convs_out = ctrl_model.middle_zero_convs_out - self.up_zero_convs_out = ctrl_model.up_zero_convs_out + self.down_zero_convs_b2c = ctrl_model.down_zero_convs_b2c + self.down_zero_convs_c2b = ctrl_model.down_zero_convs_c2b + self.mid_zero_convs_c2b = ctrl_model.mid_zero_convs_c2b + self.up_zero_convs_c2b = ctrl_model.up_zero_convs_c2b # 4 - Save base model parts self.base_time_proj = base_model.time_proj @@ -608,9 +441,9 @@ def __init__( self.base_up_subblocks = nn.ModuleList() # 4.1 - SDXL specific components - if hasattr(base_model, 'add_time_proj'): + if hasattr(base_model, "add_time_proj"): self.base_add_time_proj = base_model.add_time_proj - if hasattr(base_model, 'add_embedding'): + if hasattr(base_model, "add_embedding"): self.base_add_embedding = base_model.add_embedding # 4.2 - Decompose blocks of base model into subblocks @@ -619,10 +452,10 @@ def __init__( resnets = block.resnets attentions = block.attentions if hasattr(block, "attentions") else [None] * len(resnets) for r, a in zip(resnets, attentions): - self.base_down_subblocks.append(CrossAttnSubBlock2D.from_modules(r,a)) + self.base_down_subblocks.append(CrossAttnSubBlock2D.from_modules(r, a)) # Each Downsampler is a subblock if block.downsamplers is not None: - if len(block.downsamplers)!=1: + if len(block.downsamplers) != 1: raise ValueError( "ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL." "Therefore each down block of the base model should have only 1 downsampler (if any)." @@ -632,7 +465,7 @@ def __init__( for block in base_model.up_blocks: # Each ResNet / Attention / Upsampler triple is a subblock if block.upsamplers is not None: - if len(block.upsamplers)!=1: + if len(block.upsamplers) != 1: raise ValueError( "ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL." "Therefore each up block of the base model should have only 1 upsampler (if any)." @@ -643,9 +476,9 @@ def __init__( resnets = block.resnets attentions = block.attentions if hasattr(block, "attentions") else [None] * len(resnets) - upsamplers = [None] * (len(resnets)-1) + [upsampler] + upsamplers = [None] * (len(resnets) - 1) + [upsampler] for r, a, u in zip(resnets, attentions, upsamplers): - self.base_up_subblocks.append(CrossAttnUpSubBlock2D.from_modules(r,a,u)) + self.base_up_subblocks.append(CrossAttnUpSubBlock2D.from_modules(r, a, u)) self.base_conv_norm_out = base_model.conv_norm_out self.base_conv_act = base_model.conv_act @@ -709,9 +542,6 @@ def forward( if self.conditioning_channel_order == "bgr": controlnet_cond = torch.flip(controlnet_cond, dims=[1]) - # scale control strength - n_connections = len(self.down_zero_convs_out) + 1 + len(self.up_zero_convs_out) - # prepare attention_mask if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 @@ -757,10 +587,10 @@ def forward( if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") - if base_model.config.class_embed_type == "timestep": - class_labels = base_model.time_proj(class_labels) + if self.class_embed_type == "timestep": + class_labels = self.base_time_proj(class_labels) - class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype) + class_emb = self.base_class_embedding(class_labels).to(dtype=self.dtype) temb = temb + class_emb if self.base_addition_embed_type is None: @@ -797,38 +627,49 @@ def forward( hs_base, hs_ctrl = [], [] # Cross Control - # 0 - conv in + # 1 - conv in & down + # The base -> ctrl connections are 'delayed' by 1 subblock, because we want to 'wait' to ensure the new information from the last ctrl -> base connection is also considered + # Therefore, the connections iterate over: + # ctrl -> base: conv_in | subblock 1 | ... | subblock n + # base -> ctrl: | subblock 1 | ... | subblock n | mid block + h_base = self.base_conv_in(h_base) h_ctrl = self.ctrl_conv_in(h_ctrl) if guided_hint is not None: h_ctrl += guided_hint - h_base = h_base + self.down_zero_convs_out[0](h_ctrl) * conditioning_scale # D - add ctrl -> base + h_base = h_base + self.down_zero_convs_c2b[0](h_ctrl) * conditioning_scale # add ctrl -> base hs_base.append(h_base) hs_ctrl.append(h_ctrl) - # 1 - down - for b, c, b2c, c2b in zip(self.base_down_subblocks, self.ctrl_down_subblocks, self.down_zero_convs_in[:-1], self.down_zero_convs_out[1:]): + for b, c, b2c, c2b in zip( + self.base_down_subblocks, + self.ctrl_down_subblocks, + self.down_zero_convs_b2c[:-1], + self.down_zero_convs_c2b[1:], + ): if isinstance(b, CrossAttnSubBlock2D): additional_params = [temb, cemb, attention_mask, cross_attention_kwargs] else: additional_params = [] - h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) # A - concat base -> ctrl - h_base = b(h_base, *additional_params) # B - apply base subblock - h_ctrl = c(h_ctrl, *additional_params) # C - apply ctrl subblock - h_base = h_base + c2b(h_ctrl) * conditioning_scale # D - add ctrl -> base + h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) # concat base -> ctrl + h_base = b(h_base, *additional_params) # apply base subblock + h_ctrl = c(h_ctrl, *additional_params) # apply ctrl subblock + h_base = h_base + c2b(h_ctrl) * conditioning_scale # add ctrl -> base hs_base.append(h_base) hs_ctrl.append(h_ctrl) + h_ctrl = torch.cat([h_ctrl, self.down_zero_convs_b2c[-1](h_base)], dim=1) # concat base -> ctrl # 2 - mid - h_ctrl = torch.cat([h_ctrl, self.down_zero_convs_in[-1](h_base)], dim=1) # A - concat base -> ctrl - h_base = self.base_mid_block(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock - h_ctrl = self.ctrl_mid_block(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock - h_base = h_base + self.middle_zero_convs_out(h_ctrl) * conditioning_scale # D - add ctrl -> base + h_base = self.base_mid_block(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # apply base subblock + h_ctrl = self.ctrl_mid_block(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # apply ctrl subblock + h_base = h_base + self.mid_zero_convs_c2b(h_ctrl) * conditioning_scale # add ctrl -> base # 3 - up - for b, c2b, skip_c, skip_b in zip(self.base_up_subblocks, self.up_zero_convs_out, reversed(hs_ctrl), reversed(hs_base)): + for b, c2b, skip_c, skip_b in zip( + self.base_up_subblocks, self.up_zero_convs_c2b, reversed(hs_ctrl), reversed(hs_base) + ): h_base = h_base + c2b(skip_c) * conditioning_scale # add info from ctrl encoder h_base = torch.cat([h_base, skip_b], dim=1) # concat info from base encoder+ctrl encoder h_base = b(h_base, temb, cemb, attention_mask, cross_attention_kwargs) @@ -857,7 +698,7 @@ def __init__( out_channels: Optional[int] = None, temb_channels: Optional[int] = None, norm_num_groups: Optional[int] = 32, - has_crossattn = False, + has_crossattn=False, transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, num_attention_heads: Optional[int] = 1, cross_attention_dim: Optional[int] = 1024, @@ -916,6 +757,7 @@ def forward( lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 if self.training and self.gradient_checkpointing: + def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: @@ -990,6 +832,7 @@ def forward( hidden_states: torch.FloatTensor, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: if self.training and self.gradient_checkpointing: + def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: @@ -1014,7 +857,12 @@ def __init__(self): self.gradient_checkpointing = False @classmethod - def from_modules(cls, resnet: ResnetBlock2D, attention: Optional[Transformer2DModel] = None, upsampler: Optional[Upsample2D] = None): + def from_modules( + cls, + resnet: ResnetBlock2D, + attention: Optional[Transformer2DModel] = None, + upsampler: Optional[Upsample2D] = None, + ): """Create empty subblock and set resnet, attention and upsampler manually""" subblock = cls() subblock.resnet = resnet @@ -1036,6 +884,7 @@ def forward( lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 if self.training and self.gradient_checkpointing: + def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: @@ -1077,4 +926,4 @@ def custom_forward(*inputs): if self.upsampler is not None: hidden_states = self.upsampler(hidden_states) - return hidden_states \ No newline at end of file + return hidden_states diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 2b456f4c3d08..3bf67dfc1cdc 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -128,6 +128,12 @@ "StableDiffusionXLControlNetPipeline", ] ) + _import_structure["controlnet_xs"].extend( + [ + "StableDiffusionControlNetXSPipeline", + "StableDiffusionXLControlNetXSPipeline", + ] + ) _import_structure["deepfloyd_if"] = [ "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", @@ -355,6 +361,10 @@ StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetPipeline, ) + from .controlnet_xs import ( + StableDiffusionControlNetXSPipeline, + StableDiffusionXLControlNetXSPipeline, + ) from .deepfloyd_if import ( IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, diff --git a/src/diffusers/pipelines/controlnet_xs/__init__.py b/src/diffusers/pipelines/controlnet_xs/__init__.py new file mode 100644 index 000000000000..978278b184f9 --- /dev/null +++ b/src/diffusers/pipelines/controlnet_xs/__init__.py @@ -0,0 +1,68 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"] + _import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"] +try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_flax_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) +else: + pass # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline + from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline + + try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + pass # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py new file mode 100644 index 000000000000..7a34ef526002 --- /dev/null +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -0,0 +1,946 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetXSAddon, ControlNetXSModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSModel + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 + >>> controlnet = ControlNetXSModel.from_pretrained( + ... "UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... ).images[0] + ``` +""" + + +class StableDiffusionControlNetXSPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetXSModel`]): + Provides additional conditioning to the `unet` during the denoising process. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae>controlnet" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet_addon: ControlNetXSAddon, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + ( + vae_compatible, + cnxs_condition_downsample_factor, + vae_downsample_factor, + ) = controlnet_addon._check_if_vae_compatible(vae) + if not vae_compatible: + raise ValueError( + f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet_addon=controlnet_addon, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.controlnet = ControlNetXSModel(base_model=unet, ctrl_model=controlnet_addon) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetXSModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + ): + self.check_image(image, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetXSModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + start, end = control_guidance_start, control_guidance_end + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + clip_skip: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare image + if isinstance(controlnet, ControlNetXSModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + height, width = image.shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + dont_control = ( + i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end + ) + if dont_control: + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=True, + ).sample + else: + noise_pred = self.controlnet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=True, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py new file mode 100644 index 000000000000..5caafc4ee48b --- /dev/null +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -0,0 +1,1120 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetXSAddon, ControlNetXSModel, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSModel, AutoencoderKL + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization + >>> controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... ).images[0] + ``` +""" + + +class StableDiffusionXLControlNetXSPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet-XS guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + tokenizer_2 ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetXSModel`]: + Provides additional conditioning to the `unet` during the denoising process. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings should always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to + watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no + watermarker is used. + """ + + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae->controlnet" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet_addon: ControlNetXSAddon, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__() + + ( + vae_compatible, + cnxs_condition_downsample_factor, + vae_downsample_factor, + ) = controlnet_addon._check_if_vae_compatible(vae) + if not vae_compatible: + raise ValueError( + f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet_addon=controlnet_addon, + scheduler=scheduler, + ) + self.controlnet = ControlNetXSModel(base_model=unet, ctrl_model=controlnet_addon) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetXSModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + ): + self.check_image(image, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetXSModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + start, end = control_guidance_start, control_guidance_end + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. + control_guidance_start (`float`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is + returned, otherwise a `tuple` is returned containing the output images. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetXSModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + height, width = image.shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # predict the noise residual + dont_control = ( + i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end + ) + if dont_control: + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=True, + ).sample + else: + noise_pred = self.controlnet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=True, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # manually for max memory savings + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) From 2b73d025814f30f11f579473ad425b63369044ab Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Fri, 12 Jan 2024 16:23:38 +0100 Subject: [PATCH 08/75] Cleaned up connection creation --- src/diffusers/models/controlnet_xs.py | 80 +++++++++++++++++--------- src/diffusers/models/unet_2d_blocks.py | 2 + 2 files changed, 54 insertions(+), 28 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index b1e6c5c2105a..6109554f1e38 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -83,6 +83,46 @@ def forward(self, conditioning): class ControlNetXSAddon(ModelMixin, ConfigMixin): + + @staticmethod + def gather_base_subblock_sizes(blocks_sizes: List[int]): + """todo - comment""" + + n_blocks = len(blocks_sizes) + n_subblocks_per_block = 3 + + down_out = [] + up_in = [] + + # down_out + for b in range(n_blocks): + for i in range(n_subblocks_per_block): + if b==n_blocks-1 and i==2: + # last block has now downsampler, so has only 2 subblocks instead of 3 + continue + if i==0: + # first subblock has same input channels as in last block, + # because channels are changed by the first resnet, which is the first subblock + down_out.append(blocks_sizes[max(b-1,0)]) + else: + down_out.append(blocks_sizes[b]) + down_out.append(blocks_sizes[-1]) + + # up_in + rev_blocks_sizes = list(reversed(blocks_sizes)) + for b in range(len(rev_blocks_sizes)): + for i in range(n_subblocks_per_block): + if i==0: + up_in.append(rev_blocks_sizes[max(b-1,0)]) + else: + up_in.append(rev_blocks_sizes[b]) + + return { + "down - out": down_out, + "mid - out": blocks_sizes[-1], + "up - in": up_in, + } + @classmethod def from_unet( cls, @@ -102,12 +142,7 @@ def from_unet( "Pass exactly one of `block_out_channels` (for absolute sizing) or `control_model_ratio` (for relative sizing)." ) - channels_base = { # todo - "down - in": [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280], - "down - out": [320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280], - "mid - out": 1280, - "up - in": [1280, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320], - } + channels_base = ControlNetXSAddon.gather_base_subblock_sizes(base_model.config.block_out_channels) block_out_channels = [int(b * size_ratio) for b in base_model.config.block_out_channels] if num_attention_heads is None: @@ -122,7 +157,6 @@ def from_unet( addition_time_embed_dim=base_model.config.addition_time_embed_dim, attention_head_dim=num_attention_heads, block_out_channels=block_out_channels, - base_block_out_channels=base_model.config.block_out_channels, cross_attention_dim=base_model.config.cross_attention_dim, down_block_types=base_model.config.down_block_types, projection_class_embeddings_input_dim=base_model.config.projection_class_embeddings_input_dim, @@ -142,7 +176,6 @@ def __init__( time_embedding_dim: int = 1280, learn_time_embedding: bool = False, channels_base: Dict[str, List[Tuple[int]]] = { - "down - in": [320, 320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280], "down - out": [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280], "mid - out": 1280, "up - in": [1280, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320], @@ -151,7 +184,6 @@ def __init__( addition_time_embed_dim=None, attention_head_dim=[4], block_out_channels=[4, 8, 16, 16], - base_block_out_channels=[320, 640, 1280, 1280], cross_attention_dim=1024, down_block_types=["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], projection_class_embeddings_input_dim=None, @@ -206,20 +238,9 @@ def __init__( transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) # down - def channels_from_base(block_no, subblock_no): - """Determine channel size for extra info from base model""" - if block_no == 0: - # in 1st block: all subblocks have same channels - return base_block_out_channels[0] - else: - if subblock_no == 0: - # in 2nd+ block: the 1st subblock has same channels as in the last block - return base_block_out_channels[block_no - 1] - else: - # in 2nd+ block: after the 1st subblock, the channels have changed - return base_block_out_channels[block_no] - output_channel = block_out_channels[0] + subblock_counter = 0 + for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] @@ -228,7 +249,7 @@ def channels_from_base(block_no, subblock_no): self.down_subblocks.append( CrossAttnSubBlock2D( has_crossattn=use_crossattention, - in_channels=input_channel + channels_from_base(block_no=i, subblock_no=0), + in_channels=input_channel + channels_base['down - out'][subblock_counter], out_channels=output_channel, temb_channels=time_embedding_dim, transformer_layers_per_block=transformer_layers_per_block[i], @@ -238,10 +259,11 @@ def channels_from_base(block_no, subblock_no): norm_num_groups=norm_num_groups, ) ) + subblock_counter += 1 self.down_subblocks.append( CrossAttnSubBlock2D( has_crossattn=use_crossattention, - in_channels=output_channel + channels_from_base(block_no=i, subblock_no=1), + in_channels=output_channel + channels_base['down - out'][subblock_counter], out_channels=output_channel, temb_channels=time_embedding_dim, transformer_layers_per_block=transformer_layers_per_block[i], @@ -251,19 +273,20 @@ def channels_from_base(block_no, subblock_no): norm_num_groups=norm_num_groups, ) ) + subblock_counter += 1 if i < len(down_block_types) - 1: self.down_subblocks.append( DownSubBlock2D( - in_channels=output_channel + channels_from_base(block_no=i, subblock_no=2), + in_channels=output_channel + channels_base['down - out'][subblock_counter], out_channels=output_channel, ) ) + subblock_counter += 1 # mid - channel_from_base = base_block_out_channels[-1] self.mid_block = UNetMidBlock2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], - in_channels=block_out_channels[-1] + channel_from_base, + in_channels=block_out_channels[-1] + channels_base['down - out'][subblock_counter], out_channels=block_out_channels[-1], temb_channels=time_embedding_dim, resnet_eps=1e-05, @@ -289,9 +312,10 @@ def channels_from_base(block_no, subblock_no): self.up_zero_convs_c2b = nn.ModuleList([]) # 4.1 - Connections from base encoder to ctrl encoder + # todo - better comment # Information is passed from base to ctrl _before_ each subblock. We therefore use the 'in' channels. # As the information is concatted in ctrl, we don't need to change channel sizes. So channels in = channels out. - for c in channels_base["down - in"]: + for c in channels_base["down - out"]: # change down - in to down - out self.down_zero_convs_b2c.append(self._make_zero_conv(c, c)) # 4.2 - Connections from ctrl encoder to base encoder diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 3614fc06f301..2208c98e95bd 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -652,6 +652,8 @@ def __init__( super().__init__() out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels self.has_cross_attention = True self.num_attention_heads = num_attention_heads From 93f9d7dedc9e46784cc0e4e58f0a5c1ac8b7b0ba Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Mon, 15 Jan 2024 10:30:48 +0100 Subject: [PATCH 09/75] added debug logs --- src/diffusers/models/attention.py | 8 ++ src/diffusers/models/controlnet_xs.py | 34 +++++- src/diffusers/models/resnet.py | 8 ++ src/diffusers/umer_debug_logger.py | 146 ++++++++++++++++++++++++++ 4 files changed, 194 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/umer_debug_logger.py diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 804c34d617d3..234681e42024 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -17,6 +17,7 @@ import torch.nn.functional as F from torch import nn +from ..umer_debug_logger import udl from ..utils import USE_PEFT_BACKEND from ..utils.torch_utils import maybe_allow_in_graph from .activations import GEGLU, GELU, ApproximateGELU @@ -332,6 +333,8 @@ def forward( attention_mask=attention_mask, **cross_attention_kwargs, ) + udl.log_if("attn1", attn_output, udl.SUBBLOCKM1) + if self.use_ada_layer_norm_zero: attn_output = gate_msa.unsqueeze(1) * attn_output elif self.use_ada_layer_norm_single: @@ -370,6 +373,8 @@ def forward( **cross_attention_kwargs, ) hidden_states = attn_output + hidden_states + udl.log_if("attn2", attn_output, udl.SUBBLOCKM1) + udl.log_if("add attn2", hidden_states, udl.SUBBLOCKM1) # 4. Feed-forward if self.use_ada_layer_norm_continuous: @@ -401,6 +406,9 @@ def forward( if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) + udl.log_if("ff", ff_output, udl.SUBBLOCKM1) + udl.log_if("add ff", hidden_states, udl.SUBBLOCKM1) + return hidden_states diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 6109554f1e38..5d40964fed89 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -7,6 +7,7 @@ from torch import nn from torch.nn import functional as F +from ..umer_debug_logger import udl from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, is_torch_version, logging from .autoencoders import AutoencoderKL @@ -458,7 +459,7 @@ def __init__( self.base_time_proj = base_model.time_proj self.base_time_embedding = base_model.time_embedding self.base_class_embedding = base_model.class_embedding - self.base_addition_embed_type = base_model.addition_embed_type + self.base_addition_embed_type = base_model.config.addition_embed_type self.base_conv_in = base_model.conv_in self.base_down_subblocks = nn.ModuleList() self.base_mid_block = base_model.mid_block @@ -562,6 +563,11 @@ def forward( If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ + + udl.log_if('sample', sample, udl.SUBBLOCK) + udl.log_if('timestep', torch.tensor(timestep, dtype=torch.float32), udl.SUBBLOCK) + udl.log_if('encoder_hidden_states', encoder_hidden_states, udl.SUBBLOCK) + # check channel order if self.conditioning_channel_order == "bgr": controlnet_cond = torch.flip(controlnet_cond, dims=[1]) @@ -650,6 +656,9 @@ def forward( h_ctrl = h_base = sample hs_base, hs_ctrl = [], [] + udl.log_if('h_ctrl', h_ctrl, udl.SUBBLOCK) + udl.log_if('h_base', h_base, udl.SUBBLOCK) + # Cross Control # 1 - conv in & down # The base -> ctrl connections are 'delayed' by 1 subblock, because we want to 'wait' to ensure the new information from the last ctrl -> base connection is also considered @@ -662,6 +671,7 @@ def forward( if guided_hint is not None: h_ctrl += guided_hint h_base = h_base + self.down_zero_convs_c2b[0](h_ctrl) * conditioning_scale # add ctrl -> base + udl.log_if('add c2b', h_base, udl.SUBBLOCK) hs_base.append(h_base) hs_ctrl.append(h_ctrl) @@ -678,29 +688,49 @@ def forward( additional_params = [] h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) # concat base -> ctrl + udl.log_if('concat b2c', h_ctrl, udl.SUBBLOCK) + h_base = b(h_base, *additional_params) # apply base subblock + udl.log_if('base', h_base, udl.SUBBLOCK) + h_ctrl = c(h_ctrl, *additional_params) # apply ctrl subblock + udl.log_if('ctrl', h_ctrl, udl.SUBBLOCK) + h_base = h_base + c2b(h_ctrl) * conditioning_scale # add ctrl -> base + udl.log_if('add c2b', h_base, udl.SUBBLOCK) + hs_base.append(h_base) hs_ctrl.append(h_ctrl) - h_ctrl = torch.cat([h_ctrl, self.down_zero_convs_b2c[-1](h_base)], dim=1) # concat base -> ctrl + h_ctrl = torch.cat([h_ctrl, self.down_zero_convs_b2c[-1](h_base)], dim=1) # concat base -> ctrl + udl.log_if('concat b2c', h_ctrl, udl.SUBBLOCK) # 2 - mid h_base = self.base_mid_block(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # apply base subblock + udl.log_if('base', h_base, udl.SUBBLOCK) + h_ctrl = self.ctrl_mid_block(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # apply ctrl subblock + udl.log_if('ctrl', h_ctrl, udl.SUBBLOCK) + h_base = h_base + self.mid_zero_convs_c2b(h_ctrl) * conditioning_scale # add ctrl -> base + udl.log_if('add c2b', h_base, udl.SUBBLOCK) # 3 - up for b, c2b, skip_c, skip_b in zip( self.base_up_subblocks, self.up_zero_convs_c2b, reversed(hs_ctrl), reversed(hs_base) ): h_base = h_base + c2b(skip_c) * conditioning_scale # add info from ctrl encoder + udl.log_if('add c2b', h_base, udl.SUBBLOCK) + h_base = torch.cat([h_base, skip_b], dim=1) # concat info from base encoder+ctrl encoder h_base = b(h_base, temb, cemb, attention_mask, cross_attention_kwargs) + udl.log_if('base', h_base, udl.SUBBLOCK) h_base = self.base_conv_norm_out(h_base) h_base = self.base_conv_act(h_base) h_base = self.base_conv_out(h_base) + udl.log_if('conv_out', h_base, udl.SUBBLOCK) + + udl.stop_if(udl.SUBBLOCK, 'It is done, my dude. Let us look at these tensors.') if not return_dict: return h_base diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index bbfb71ca3fbf..687e32e0fbda 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -20,6 +20,7 @@ import torch.nn as nn import torch.nn.functional as F +from ..umer_debug_logger import udl from ..utils import USE_PEFT_BACKEND from .activations import get_activation from .attention_processor import SpatialNorm @@ -223,6 +224,7 @@ def forward( ) hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states) + udl.log_if("conv1", hidden_states, udl.SUBBLOCKM1) if self.time_emb_proj is not None: if not self.skip_time_act: @@ -233,6 +235,8 @@ def forward( else self.time_emb_proj(temb)[:, :, None, None] ) + udl.log_if("temb", hidden_states, udl.SUBBLOCKM1) + if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb @@ -250,6 +254,8 @@ def forward( hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states) + udl.log_if("conv2", hidden_states, udl.SUBBLOCKM1) + if self.conv_shortcut is not None: input_tensor = ( self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor) @@ -257,6 +263,8 @@ def forward( output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + udl.log_if("out", output_tensor, udl.SUBBLOCKM1) + return output_tensor diff --git a/src/diffusers/umer_debug_logger.py b/src/diffusers/umer_debug_logger.py new file mode 100644 index 000000000000..0a7f0349bdeb --- /dev/null +++ b/src/diffusers/umer_debug_logger.py @@ -0,0 +1,146 @@ +# Logger to help me (UmerHA) debug controlnet-xs + +import csv +import inspect +import os +import shutil +from datetime import datetime +from types import SimpleNamespace + +import torch + + +class UmerDebugLogger: + _FILE = "udl.csv" + + BLOCK = 'block' + SUBBLOCK = 'subblock' + SUBBLOCKM1 = 'subblock-minus-1' + allowed_conditions = [BLOCK, SUBBLOCK, SUBBLOCKM1] + + def __init__(self, log_dir="logs", condition=None): + self.log_dir, self.condition, self.tensor_counter = log_dir, condition, 0 + os.makedirs(log_dir, exist_ok=True) + self.fields = ["timestamp", "cls", "fn", "shape", "msg", "condition", "tensor_file"] + self.create_file() + self.warned_of_no_condition = False + print( + "Info: `UmerDebugLogger` created. This is a logging class that will be deleted when the PR to integrate ControlNet-XS is done." + ) + + @property + def full_file_path(self): + return os.path.join(self.log_dir, self._FILE) + + def create_file(self): + file = self.full_file_path + if not os.path.isfile(file): + with open(file, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=self.fields) + writer.writeheader() + + def set_dir(self, log_dir, clear=False): + self.log_dir = log_dir + if clear: + self.clear_logs() + self.create_file() + + def clear_logs(self): + shutil.rmtree(self.log_dir, ignore_errors=True) + os.makedirs(self.log_dir, exist_ok=True) + self.create_file() + + def set_condition(self, condition): + if not isinstance(condition, list): condition = [condition] + self.condition = condition + + def check_condition(self, condition): + if not condition in self.allowed_conditions: raise ValueError(f'Unknown condition: {condition}') + return condition in self.condition + + def log_if(self, msg, t, condition, *, print_=False): + self.maybe_warn_of_no_condition() + + # Use inspect to get the current frame and then go back one level to find caller + frame = inspect.currentframe() + caller_frame = frame.f_back + caller_info = inspect.getframeinfo(caller_frame) + + # Extract class and function name from the caller + cls_name = ( + caller_frame.f_locals.get("self", None).__class__.__name__ if "self" in caller_frame.f_locals else None + ) + function_name = caller_info.function + + if not hasattr(t, "shape"): + t = torch.tensor(t) + t = t.cpu().detach() + + if self.check_condition(condition): + # Save tensor to a file + tensor_filename = f"tensor_{self.tensor_counter}.pt" + torch.save(t, os.path.join(self.log_dir, tensor_filename)) + self.tensor_counter += 1 + + # Log information to CSV + log_info = { + "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "cls": cls_name, + "fn": function_name, + "shape": str(list(t.shape)), + "msg": msg, + "condition": condition, + "tensor_file": tensor_filename, + } + + with open(self.full_file_path, "a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=self.fields) + writer.writerow(log_info) + + if print_: + print(f"{msg}\t{t.flatten()[:10]}") + + def print_if(self, msg, conditions, end="\n"): + self.maybe_warn_of_no_condition() + if not isinstance(conditions, (tuple, list)): + conditions = [conditions] + if any(self.condition == c for c in conditions): + print(msg, end=end) + + def stop_if(self, condition, funny_msg): + if self.check_condition(condition): + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + raise SystemExit(f"{funny_msg} - {current_time}") + + def maybe_warn_of_no_condition(self): + if self.condition is None and not self.warned_of_no_condition: + print("Info: No condition set for UmerDebugLogger") + self.warned_of_no_condition = True + + def get_log_objects(self): + log_objects = [] + file = self.full_file_path + with open(file, newline="") as f: + reader = csv.DictReader(f) + for row in reader: + row["tensor"] = torch.load(os.path.join(self.log_dir, row["tensor_file"])) + row["head"] = row["tensor"].flatten()[:10] + del row["tensor_file"] + log_objects.append(SimpleNamespace(**row)) + return log_objects + + @classmethod + def load_log_objects_from_dir(self, log_dir): + file = os.path.join(log_dir, self._FILE) + log_objects = [] + with open(file, newline="") as f: + reader = csv.DictReader(f) + for row in reader: + row["t"] = torch.load(os.path.join(log_dir, row["tensor_file"])) + row["head"] = row["t"].flatten()[:10] + del row["tensor_file"] + log_objects.append(SimpleNamespace(**row)) + return log_objects + + +udl = UmerDebugLogger() From 36bb8f9917051eed8e814e436d11307430633955 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Mon, 15 Jan 2024 11:19:28 +0100 Subject: [PATCH 10/75] updated logs --- src/diffusers/models/attention.py | 4 +++- src/diffusers/models/resnet.py | 4 +++- src/diffusers/models/transformer_2d.py | 11 +++++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 234681e42024..0e27117ffa88 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -333,7 +333,6 @@ def forward( attention_mask=attention_mask, **cross_attention_kwargs, ) - udl.log_if("attn1", attn_output, udl.SUBBLOCKM1) if self.use_ada_layer_norm_zero: attn_output = gate_msa.unsqueeze(1) * attn_output @@ -344,6 +343,9 @@ def forward( if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) + udl.log_if("attn1", attn_output, udl.SUBBLOCKM1) + udl.log_if("add attn1", hidden_states, udl.SUBBLOCKM1) + # 2.5 GLIGEN Control if gligen_kwargs is not None: hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 687e32e0fbda..f50f1e4a4992 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -235,11 +235,13 @@ def forward( else self.time_emb_proj(temb)[:, :, None, None] ) - udl.log_if("temb", hidden_states, udl.SUBBLOCKM1) + udl.log_if("temb", temb, udl.SUBBLOCKM1) if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb + udl.log_if("add temb", hidden_states, udl.SUBBLOCKM1) + if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": hidden_states = self.norm2(hidden_states, temb) else: diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 128395cc161a..0a0019b1b175 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -18,6 +18,7 @@ import torch.nn.functional as F from torch import nn +from umer_debug_logger import udl from ..configuration_utils import ConfigMixin, register_to_config from ..models.embeddings import ImagePositionalEmbeddings from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version @@ -325,6 +326,8 @@ def forward( residual = hidden_states hidden_states = self.norm(hidden_states) + udl.log_if('norm', hidden_states, udl.SUBBLOCKM1) + if not self.use_linear_projection: hidden_states = ( self.proj_in(hidden_states, scale=lora_scale) @@ -342,9 +345,13 @@ def forward( else self.proj_in(hidden_states) ) + udl.log_if('proj_in', hidden_states, udl.SUBBLOCKM1) + elif self.is_input_vectorized: + print('umer: wtf, this happened?') hidden_states = self.latent_image_embedding(hidden_states) elif self.is_input_patches: + print('umer: wtf, why did this happen?') height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size hidden_states = self.pos_embed(hidden_states) @@ -358,6 +365,8 @@ def forward( timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype ) + + # 2. Blocks if self.caption_projection is not None: batch_size = hidden_states.shape[0] @@ -453,6 +462,8 @@ def custom_forward(*inputs): shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) ) + udl.log_if('proj_out', output, udl.SUBBLOCKM1) + if not return_dict: return (output,) From 0878d37266ea9ffdabbe413bb7ba08b4e86beb2b Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Mon, 15 Jan 2024 13:58:43 +0100 Subject: [PATCH 11/75] logs: added input loading --- src/diffusers/models/controlnet_xs.py | 10 +++++---- src/diffusers/umer_debug_logger.py | 31 +++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 5d40964fed89..be0181a52003 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -564,10 +564,6 @@ def forward( tuple is returned where the first element is the sample tensor. """ - udl.log_if('sample', sample, udl.SUBBLOCK) - udl.log_if('timestep', torch.tensor(timestep, dtype=torch.float32), udl.SUBBLOCK) - udl.log_if('encoder_hidden_states', encoder_hidden_states, udl.SUBBLOCK) - # check channel order if self.conditioning_channel_order == "bgr": controlnet_cond = torch.flip(controlnet_cond, dims=[1]) @@ -594,6 +590,12 @@ def forward( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) + sample, timesteps, encoder_hidden_states = udl.do_input_action(x=sample, t=timesteps, xcross=encoder_hidden_states) + + udl.log_if('sample', sample, udl.SUBBLOCK) + udl.log_if('timesteps', timesteps, udl.SUBBLOCK) + udl.log_if('encoder_hidden_states', encoder_hidden_states, udl.SUBBLOCK) + t_emb = self.base_time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors diff --git a/src/diffusers/umer_debug_logger.py b/src/diffusers/umer_debug_logger.py index 0a7f0349bdeb..e4ca2306bb93 100644 --- a/src/diffusers/umer_debug_logger.py +++ b/src/diffusers/umer_debug_logger.py @@ -18,6 +18,8 @@ class UmerDebugLogger: SUBBLOCKM1 = 'subblock-minus-1' allowed_conditions = [BLOCK, SUBBLOCK, SUBBLOCKM1] + input_files = None + def __init__(self, log_dir="logs", condition=None): self.log_dir, self.condition, self.tensor_counter = log_dir, condition, 0 os.makedirs(log_dir, exist_ok=True) @@ -142,5 +144,34 @@ def load_log_objects_from_dir(self, log_dir): log_objects.append(SimpleNamespace(**row)) return log_objects + def save_input(self, dir_, x, t, xcross): + self.input_files = SimpleNamespace( + x=os.path.join(dir_, x), + t=os.path.join(dir_, t), + xcross=os.path.join(dir_, xcross), + ) + self.input_action = 'save' + + def load_input(self, dir_, x, t, xcross): + self.input_files = SimpleNamespace( + x=os.path.join(dir_, x), + t=os.path.join(dir_, t), + xcross=os.path.join(dir_, xcross), + ) + self.input_action = 'save' + + def do_input_action(self, x, t, xcross): + assert self.input_files is not None, "self.input_files not set! Use save_input or load_input" + assert self.input_action in ['save', 'load'] + if self.input_action == 'save': + torch.save(x, os.path.join(self.log_dir, self.input_files.x)) + torch.save(t, os.path.join(self.log_dir, self.input_files.t)) + torch.save(xcross, os.path.join(self.log_dir, self.input_files.xcross)) + else: + x = torch.load(os.path.join(self.log_dir, self.input_files.x)) + t = torch.load(os.path.join(self.log_dir, self.input_files.t)) + xcross = torch.load(os.path.join(self.log_dir, self.input_files.xcross)) + return x, t, xcross + udl = UmerDebugLogger() From 72b29dee8b6fe42563b3dc93bcc5574e281c9f84 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Mon, 15 Jan 2024 15:17:10 +0100 Subject: [PATCH 12/75] Update umer_debug_logger.py --- src/diffusers/umer_debug_logger.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/umer_debug_logger.py b/src/diffusers/umer_debug_logger.py index e4ca2306bb93..747663831265 100644 --- a/src/diffusers/umer_debug_logger.py +++ b/src/diffusers/umer_debug_logger.py @@ -164,13 +164,15 @@ def do_input_action(self, x, t, xcross): assert self.input_files is not None, "self.input_files not set! Use save_input or load_input" assert self.input_action in ['save', 'load'] if self.input_action == 'save': - torch.save(x, os.path.join(self.log_dir, self.input_files.x)) - torch.save(t, os.path.join(self.log_dir, self.input_files.t)) - torch.save(xcross, os.path.join(self.log_dir, self.input_files.xcross)) + torch.save(x, self.input_files.x) + torch.save(t, self.input_files.t) + torch.save(xcross, self.input_files.xcross) + print('[udl] Input saved') else: - x = torch.load(os.path.join(self.log_dir, self.input_files.x)) - t = torch.load(os.path.join(self.log_dir, self.input_files.t)) - xcross = torch.load(os.path.join(self.log_dir, self.input_files.xcross)) + x = torch.load(self.input_files.x) + t = torch.load( self.input_files.t) + xcross = torch.load(self.input_files.xcross) + print('[udl] Input loaded') return x, t, xcross From 773721fa55f2a2f69bbb20f0658b411f56f02a05 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Mon, 15 Jan 2024 15:43:45 +0100 Subject: [PATCH 13/75] log: Loading hint --- src/diffusers/models/controlnet_xs.py | 8 +++++++- src/diffusers/models/transformer_2d.py | 2 +- src/diffusers/umer_debug_logger.py | 12 ++++++++---- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index be0181a52003..b709f7705cc6 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -590,11 +590,17 @@ def forward( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) - sample, timesteps, encoder_hidden_states = udl.do_input_action(x=sample, t=timesteps, xcross=encoder_hidden_states) + sample, timesteps, encoder_hidden_states, controlnet_cond = udl.do_input_action( + x=sample, + t=timesteps, + xcross=encoder_hidden_states, + hint=controlnet_cond + ) udl.log_if('sample', sample, udl.SUBBLOCK) udl.log_if('timesteps', timesteps, udl.SUBBLOCK) udl.log_if('encoder_hidden_states', encoder_hidden_states, udl.SUBBLOCK) + udl.log_if('controlnet_cond', controlnet_cond, udl.SUBBLOCK) t_emb = self.base_time_proj(timesteps) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 0a0019b1b175..62313f2fda38 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from torch import nn -from umer_debug_logger import udl +from ..umer_debug_logger import udl from ..configuration_utils import ConfigMixin, register_to_config from ..models.embeddings import ImagePositionalEmbeddings from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version diff --git a/src/diffusers/umer_debug_logger.py b/src/diffusers/umer_debug_logger.py index 747663831265..a6b6e99473d7 100644 --- a/src/diffusers/umer_debug_logger.py +++ b/src/diffusers/umer_debug_logger.py @@ -144,36 +144,40 @@ def load_log_objects_from_dir(self, log_dir): log_objects.append(SimpleNamespace(**row)) return log_objects - def save_input(self, dir_, x, t, xcross): + def save_input(self, dir_, x, t, xcross, hint): self.input_files = SimpleNamespace( x=os.path.join(dir_, x), t=os.path.join(dir_, t), xcross=os.path.join(dir_, xcross), + hint=os.path.join(dir_,hint) ) self.input_action = 'save' - def load_input(self, dir_, x, t, xcross): + def load_input(self, dir_, x, t, xcross, hint): self.input_files = SimpleNamespace( x=os.path.join(dir_, x), t=os.path.join(dir_, t), xcross=os.path.join(dir_, xcross), + hint=os.path.join(dir_,hint) ) self.input_action = 'save' - def do_input_action(self, x, t, xcross): + def do_input_action(self, x, t, xcross, hint): assert self.input_files is not None, "self.input_files not set! Use save_input or load_input" assert self.input_action in ['save', 'load'] if self.input_action == 'save': torch.save(x, self.input_files.x) torch.save(t, self.input_files.t) torch.save(xcross, self.input_files.xcross) + torch.save(hint, self.input_files.hint) print('[udl] Input saved') else: x = torch.load(self.input_files.x) t = torch.load( self.input_files.t) xcross = torch.load(self.input_files.xcross) + hint = torch.load(self.input_files.hint) print('[udl] Input loaded') - return x, t, xcross + return x, t, xcross, hint udl = UmerDebugLogger() From c6c831a7ff2f67dd35eeeaa3d3a4f351782b58c2 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Mon, 15 Jan 2024 16:06:40 +0100 Subject: [PATCH 14/75] Update umer_debug_logger.py --- src/diffusers/umer_debug_logger.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/umer_debug_logger.py b/src/diffusers/umer_debug_logger.py index a6b6e99473d7..88188942c495 100644 --- a/src/diffusers/umer_debug_logger.py +++ b/src/diffusers/umer_debug_logger.py @@ -160,7 +160,7 @@ def load_input(self, dir_, x, t, xcross, hint): xcross=os.path.join(dir_, xcross), hint=os.path.join(dir_,hint) ) - self.input_action = 'save' + self.input_action = 'load' def do_input_action(self, x, t, xcross, hint): assert self.input_files is not None, "self.input_files not set! Use save_input or load_input" @@ -172,10 +172,10 @@ def do_input_action(self, x, t, xcross, hint): torch.save(hint, self.input_files.hint) print('[udl] Input saved') else: - x = torch.load(self.input_files.x) - t = torch.load( self.input_files.t) - xcross = torch.load(self.input_files.xcross) - hint = torch.load(self.input_files.hint) + x = torch.load(self.input_files.x, map_location=x.device) + t = torch.load( self.input_files.t, map_location=t.device) + xcross = torch.load(self.input_files.xcross, map_location=xcross.device) + hint = torch.load(self.input_files.hint, map_location=hint.device) print('[udl] Input loaded') return x, t, xcross, hint From d20d4bf6042d7849e80155504a17dc0979e2101d Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Tue, 16 Jan 2024 08:54:31 +0100 Subject: [PATCH 15/75] added logs --- src/diffusers/models/controlnet_xs.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index b709f7705cc6..92c567d8bc39 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -675,7 +675,9 @@ def forward( # base -> ctrl: | subblock 1 | ... | subblock n | mid block h_base = self.base_conv_in(h_base) + udl.log_if('base', h_ctrl, udl.SUBBLOCK) h_ctrl = self.ctrl_conv_in(h_ctrl) + udl.log_if('ctrl', h_ctrl, udl.SUBBLOCK) if guided_hint is not None: h_ctrl += guided_hint h_base = h_base + self.down_zero_convs_c2b[0](h_ctrl) * conditioning_scale # add ctrl -> base From 8145d36f097c4589014166c8483228c06f3f2c92 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Tue, 16 Jan 2024 10:03:58 +0100 Subject: [PATCH 16/75] Changed debug logging --- src/diffusers/models/controlnet_xs.py | 2 +- src/diffusers/umer_debug_logger.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 92c567d8bc39..503cafc39db2 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -675,7 +675,7 @@ def forward( # base -> ctrl: | subblock 1 | ... | subblock n | mid block h_base = self.base_conv_in(h_base) - udl.log_if('base', h_ctrl, udl.SUBBLOCK) + udl.log_if('base', h_base, udl.SUBBLOCK) h_ctrl = self.ctrl_conv_in(h_ctrl) udl.log_if('ctrl', h_ctrl, udl.SUBBLOCK) if guided_hint is not None: diff --git a/src/diffusers/umer_debug_logger.py b/src/diffusers/umer_debug_logger.py index 88188942c495..71f2c5d5fae3 100644 --- a/src/diffusers/umer_debug_logger.py +++ b/src/diffusers/umer_debug_logger.py @@ -13,10 +13,11 @@ class UmerDebugLogger: _FILE = "udl.csv" + INPUT_SAVE = 'input_save' BLOCK = 'block' SUBBLOCK = 'subblock' SUBBLOCKM1 = 'subblock-minus-1' - allowed_conditions = [BLOCK, SUBBLOCK, SUBBLOCKM1] + allowed_conditions = [INPUT_SAVE, BLOCK, SUBBLOCK, SUBBLOCKM1] input_files = None @@ -170,13 +171,15 @@ def do_input_action(self, x, t, xcross, hint): torch.save(t, self.input_files.t) torch.save(xcross, self.input_files.xcross) torch.save(hint, self.input_files.hint) - print('[udl] Input saved') + assert x.shape[0]==t.shape[0]==xcross.shape[0]==hint.shape[0] + print(f'[udl] Input saved (batch size = {x.shape[0]})') else: x = torch.load(self.input_files.x, map_location=x.device) t = torch.load( self.input_files.t, map_location=t.device) xcross = torch.load(self.input_files.xcross, map_location=xcross.device) hint = torch.load(self.input_files.hint, map_location=hint.device) - print('[udl] Input loaded') + assert x.shape[0]==t.shape[0]==xcross.shape[0]==hint.shape[0] + print('f[udl] Input loaded (batch size = {x.shape[0]})') return x, t, xcross, hint From d657233f92f43546581ff727b9fe14b1186d88b8 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Tue, 16 Jan 2024 12:41:25 +0100 Subject: [PATCH 17/75] debug: added more logs --- src/diffusers/models/controlnet_xs.py | 2 +- src/diffusers/models/resnet.py | 16 +++++++++++----- src/diffusers/umer_debug_logger.py | 2 +- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 503cafc39db2..d0885ce59169 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -598,7 +598,7 @@ def forward( ) udl.log_if('sample', sample, udl.SUBBLOCK) - udl.log_if('timesteps', timesteps, udl.SUBBLOCK) + udl.log_if('timestep', timesteps, udl.SUBBLOCK) udl.log_if('encoder_hidden_states', encoder_hidden_states, udl.SUBBLOCK) udl.log_if('controlnet_cond', controlnet_cond, udl.SUBBLOCK) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index f50f1e4a4992..1cf9f98bb23e 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -189,12 +189,16 @@ def forward( ) -> torch.FloatTensor: hidden_states = input_tensor + udl.log_if('res: input', hidden_states, udl.SUBBLOCKM1) + if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": hidden_states = self.norm1(hidden_states, temb) else: hidden_states = self.norm1(hidden_states) + udl.log_if('res: norm1', hidden_states, udl.SUBBLOCKM1) hidden_states = self.nonlinearity(hidden_states) + udl.log_if('res: nonlin', hidden_states, udl.SUBBLOCKM1) if self.upsample is not None: # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 @@ -223,8 +227,10 @@ def forward( else self.downsample(hidden_states) ) + udl.log_if('res: updown', hidden_states, udl.SUBBLOCKM1) + hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states) - udl.log_if("conv1", hidden_states, udl.SUBBLOCKM1) + udl.log_if('res: conv1', hidden_states, udl.SUBBLOCKM1) if self.time_emb_proj is not None: if not self.skip_time_act: @@ -235,12 +241,12 @@ def forward( else self.time_emb_proj(temb)[:, :, None, None] ) - udl.log_if("temb", temb, udl.SUBBLOCKM1) + udl.log_if('res: temb', temb, udl.SUBBLOCKM1) if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb - udl.log_if("add temb", hidden_states, udl.SUBBLOCKM1) + udl.log_if('res: add temb', hidden_states, udl.SUBBLOCKM1) if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": hidden_states = self.norm2(hidden_states, temb) @@ -256,7 +262,7 @@ def forward( hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states) - udl.log_if("conv2", hidden_states, udl.SUBBLOCKM1) + udl.log_if('res: conv2', hidden_states, udl.SUBBLOCKM1) if self.conv_shortcut is not None: input_tensor = ( @@ -265,7 +271,7 @@ def forward( output_tensor = (input_tensor + hidden_states) / self.output_scale_factor - udl.log_if("out", output_tensor, udl.SUBBLOCKM1) + udl.log_if('res: out', output_tensor, udl.SUBBLOCKM1) return output_tensor diff --git a/src/diffusers/umer_debug_logger.py b/src/diffusers/umer_debug_logger.py index 71f2c5d5fae3..a283c31d2188 100644 --- a/src/diffusers/umer_debug_logger.py +++ b/src/diffusers/umer_debug_logger.py @@ -179,7 +179,7 @@ def do_input_action(self, x, t, xcross, hint): xcross = torch.load(self.input_files.xcross, map_location=xcross.device) hint = torch.load(self.input_files.hint, map_location=hint.device) assert x.shape[0]==t.shape[0]==xcross.shape[0]==hint.shape[0] - print('f[udl] Input loaded (batch size = {x.shape[0]})') + print(f'[udl] Input loaded (batch size = {x.shape[0]})') return x, t, xcross, hint From 1aa848c5e28b74323350e5b5ada7822e4f8a7a59 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 17 Jan 2024 14:34:30 +0100 Subject: [PATCH 18/75] Fixed num_norm_groups --- src/diffusers/models/controlnet_xs.py | 30 ++++++++++++++++++++------ src/diffusers/models/unet_2d_blocks.py | 7 ++++-- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index d0885ce59169..97144d55294f 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -191,7 +191,7 @@ def __init__( sample_size=96, transformer_layers_per_block: Union[int, Tuple[int]] = 1, upcast_attention=True, - norm_num_groups=4, + norm_num_groups=32, ): super().__init__() @@ -285,15 +285,19 @@ def __init__( subblock_counter += 1 # mid + mid_in_channels = block_out_channels[-1] + channels_base['down - out'][subblock_counter] + mid_out_channels = block_out_channels[-1] + self.mid_block = UNetMidBlock2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], - in_channels=block_out_channels[-1] + channels_base['down - out'][subblock_counter], - out_channels=block_out_channels[-1], + in_channels=mid_in_channels, + out_channels=mid_out_channels, temb_channels=time_embedding_dim, resnet_eps=1e-05, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[-1], - resnet_groups=norm_num_groups, + resnet_groups=find_largest_factor(mid_in_channels, norm_num_groups), + resnet_groups_out=find_largest_factor(mid_out_channels, norm_num_groups), use_linear_projection=True, upcast_attention=upcast_attention, ) @@ -754,6 +758,17 @@ def zero_module(module): return module +def find_largest_factor(number, max_factor): + factor = max_factor + if factor >= number: + return number + while factor != 0: + residual = number % factor + if residual == 0: + return factor + factor -= 1 + + class CrossAttnSubBlock2D(nn.Module): def __init__( self, @@ -772,6 +787,7 @@ def __init__( self.gradient_checkpointing = False if is_empty: + # todo umer: comment return self.in_channels = in_channels @@ -781,7 +797,8 @@ def __init__( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, - groups=norm_num_groups, + groups=find_largest_factor(in_channels, start=norm_num_groups), + groups_out=find_largest_factor(out_channels, start=norm_num_groups), eps=1e-5, ) @@ -794,7 +811,7 @@ def __init__( cross_attention_dim=cross_attention_dim, use_linear_projection=True, upcast_attention=upcast_attention, - norm_num_groups=norm_num_groups, + norm_num_groups=find_largest_factor(out_channels, start=norm_num_groups), ) else: self.attention = None @@ -875,6 +892,7 @@ def __init__( self.gradient_checkpointing = False if is_empty: + # todo umer: comment return self.in_channels = in_channels diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 2208c98e95bd..cf978015b5ff 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -640,6 +640,7 @@ def __init__( resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, + resnet_groups_out: Optional[int] = None, resnet_pre_norm: bool = True, num_attention_heads: int = 1, output_scale_factor: float = 1.0, @@ -658,6 +659,7 @@ def __init__( self.has_cross_attention = True self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + resnet_groups_out = resnet_groups_out or resnet_groups # support for variable transformer layers per block if isinstance(transformer_layers_per_block, int): @@ -671,6 +673,7 @@ def __init__( temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, + groups_out=resnet_groups_out, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, @@ -689,7 +692,7 @@ def __init__( in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, + norm_num_groups=resnet_groups_out, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, @@ -712,7 +715,7 @@ def __init__( out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, - groups=resnet_groups, + groups=resnet_groups_out, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, From f2e63665e7ecedb4e4cc22fcebeb344fcbac5c51 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 17 Jan 2024 16:36:05 +0100 Subject: [PATCH 19/75] Debug: Logging all of SDXL input --- src/diffusers/models/controlnet_xs.py | 13 ++-- src/diffusers/umer_debug_logger.py | 106 ++++++++++++++++++-------- 2 files changed, 81 insertions(+), 38 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 97144d55294f..cfec32fb31e5 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -594,12 +594,15 @@ def forward( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) - sample, timesteps, encoder_hidden_states, controlnet_cond = udl.do_input_action( + sample, timesteps, encoder_hidden_states, controlnet_cond, text_embeds, time_ids = udl.do_input_action( x=sample, t=timesteps, xcross=encoder_hidden_states, - hint=controlnet_cond + hint=controlnet_cond, + text_embeds=added_cond_kwargs.get('text_embeds', None), + time_ids=added_cond_kwargs.get('time_ids', None), ) + udl.stop_if(udl.INPUT_SAVE, 'Stopping because I only wanted to save input') udl.log_if('sample', sample, udl.SUBBLOCK) udl.log_if('timestep', timesteps, udl.SUBBLOCK) @@ -797,8 +800,8 @@ def __init__( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, - groups=find_largest_factor(in_channels, start=norm_num_groups), - groups_out=find_largest_factor(out_channels, start=norm_num_groups), + groups=find_largest_factor(in_channels, max_factor=norm_num_groups), + groups_out=find_largest_factor(out_channels, max_factor=norm_num_groups), eps=1e-5, ) @@ -811,7 +814,7 @@ def __init__( cross_attention_dim=cross_attention_dim, use_linear_projection=True, upcast_attention=upcast_attention, - norm_num_groups=find_largest_factor(out_channels, start=norm_num_groups), + norm_num_groups=find_largest_factor(out_channels, max_factor=norm_num_groups), ) else: self.attention = None diff --git a/src/diffusers/umer_debug_logger.py b/src/diffusers/umer_debug_logger.py index a283c31d2188..fe70156a475b 100644 --- a/src/diffusers/umer_debug_logger.py +++ b/src/diffusers/umer_debug_logger.py @@ -64,6 +64,9 @@ def check_condition(self, condition): def log_if(self, msg, t, condition, *, print_=False): self.maybe_warn_of_no_condition() + if not self.check_condition(condition): + return + # Use inspect to get the current frame and then go back one level to find caller frame = inspect.currentframe() caller_frame = frame.f_back @@ -78,30 +81,29 @@ def log_if(self, msg, t, condition, *, print_=False): if not hasattr(t, "shape"): t = torch.tensor(t) t = t.cpu().detach() - - if self.check_condition(condition): - # Save tensor to a file - tensor_filename = f"tensor_{self.tensor_counter}.pt" - torch.save(t, os.path.join(self.log_dir, tensor_filename)) - self.tensor_counter += 1 - - # Log information to CSV - log_info = { - "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "cls": cls_name, - "fn": function_name, - "shape": str(list(t.shape)), - "msg": msg, - "condition": condition, - "tensor_file": tensor_filename, - } - - with open(self.full_file_path, "a", newline="") as f: - writer = csv.DictWriter(f, fieldnames=self.fields) - writer.writerow(log_info) - - if print_: - print(f"{msg}\t{t.flatten()[:10]}") + + # Save tensor to a file + tensor_filename = f"tensor_{self.tensor_counter}.pt" + torch.save(t, os.path.join(self.log_dir, tensor_filename)) + self.tensor_counter += 1 + + # Log information to CSV + log_info = { + "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "cls": cls_name, + "fn": function_name, + "shape": str(list(t.shape)), + "msg": msg, + "condition": condition, + "tensor_file": tensor_filename, + } + + with open(self.full_file_path, "a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=self.fields) + writer.writerow(log_info) + + if print_: + print(f"{msg}\t{t.flatten()[:10]}") def print_if(self, msg, conditions, end="\n"): self.maybe_warn_of_no_condition() @@ -145,42 +147,80 @@ def load_log_objects_from_dir(self, log_dir): log_objects.append(SimpleNamespace(**row)) return log_objects - def save_input(self, dir_, x, t, xcross, hint): - self.input_files = SimpleNamespace( + def save_input(self, dir_, x, t, xcross, hint, text_embeds=None, time_ids=None): + assert (text_embeds is None and time_ids is None) or (text_embeds is not None and time_ids is not None) + is_sdxl = text_embeds is not None + inputs = dict( x=os.path.join(dir_, x), t=os.path.join(dir_, t), xcross=os.path.join(dir_, xcross), hint=os.path.join(dir_,hint) ) + if is_sdxl: + inputs.update(dict( + text_embeds=os.path.join(dir_, x), + time_ids=os.path.join(dir_, time_ids), + )) + self.input_files = SimpleNamespace(**inputs) self.input_action = 'save' - def load_input(self, dir_, x, t, xcross, hint): - self.input_files = SimpleNamespace( + def load_input(self, dir_, x, t, xcross, hint, text_embeds=None, time_ids=None): + assert (text_embeds is None and time_ids is None) or (text_embeds is not None and time_ids is not None) + is_sdxl = text_embeds is not None + inputs = dict( x=os.path.join(dir_, x), t=os.path.join(dir_, t), xcross=os.path.join(dir_, xcross), hint=os.path.join(dir_,hint) ) + if is_sdxl: + inputs.update(dict( + text_embeds=os.path.join(dir_, x), + time_ids=os.path.join(dir_, time_ids), + )) + self.input_files = SimpleNamespace(**inputs) self.input_action = 'load' - def do_input_action(self, x, t, xcross, hint): - assert self.input_files is not None, "self.input_files not set! Use save_input or load_input" - assert self.input_action in ['save', 'load'] + def dont_process_input(self): + self.input_action = 'none' + self.input_files = {} + + def do_input_action(self, x, t, xcross, hint, text_embeds=None, time_ids=None): + assert self.input_files is not None, "self.input_files not set! Use `save_input`, `load_input` or `dont_process_input`" + assert self.input_action in ['save', 'load', 'none'] + assert (text_embeds is None and time_ids is None) or (text_embeds is not None and time_ids is not None) + is_sdxl = text_embeds is not None + if self.input_action == 'save': torch.save(x, self.input_files.x) torch.save(t, self.input_files.t) torch.save(xcross, self.input_files.xcross) torch.save(hint, self.input_files.hint) + assert x.shape[0]==t.shape[0]==xcross.shape[0]==hint.shape[0] + + if is_sdxl: + torch.save(text_embeds, self.input_files.text_embeds) + torch.save(time_ids, self.input_files.time_ids) + assert x.shape[0]==text_embeds.shape[0]==time_ids.shape[0] + print(f'[udl] Input saved (batch size = {x.shape[0]})') - else: + elif self.input_action == 'load': x = torch.load(self.input_files.x, map_location=x.device) t = torch.load( self.input_files.t, map_location=t.device) xcross = torch.load(self.input_files.xcross, map_location=xcross.device) hint = torch.load(self.input_files.hint, map_location=hint.device) assert x.shape[0]==t.shape[0]==xcross.shape[0]==hint.shape[0] + + if is_sdxl: + text_embeds = torch.load(self.input_files.text_embeds, map_location=text_embeds.device) + time_ids = torch.load(self.input_files.time_ids, map_location=time_ids.device) + assert x.shape[0]==text_embeds.shape[0]==time_ids.shape[0] + print(f'[udl] Input loaded (batch size = {x.shape[0]})') - return x, t, xcross, hint + else: + print(f'[udl] Neither saving nor loading input (batch size = {x.shape[0]})') + return x, t, xcross, hint, text_embeds, time_ids udl = UmerDebugLogger() From 507abfad1a259bfada8296440147caeef59a0df9 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 17 Jan 2024 21:13:03 +0100 Subject: [PATCH 20/75] Update umer_debug_logger.py --- src/diffusers/umer_debug_logger.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/diffusers/umer_debug_logger.py b/src/diffusers/umer_debug_logger.py index fe70156a475b..9e0f888e3fbe 100644 --- a/src/diffusers/umer_debug_logger.py +++ b/src/diffusers/umer_debug_logger.py @@ -147,7 +147,7 @@ def load_log_objects_from_dir(self, log_dir): log_objects.append(SimpleNamespace(**row)) return log_objects - def save_input(self, dir_, x, t, xcross, hint, text_embeds=None, time_ids=None): + def save_input(self, dir_, x, t, xcross, hint, text_embeds=None, time_ids=None, minimize_bs=True): assert (text_embeds is None and time_ids is None) or (text_embeds is not None and time_ids is not None) is_sdxl = text_embeds is not None inputs = dict( @@ -158,11 +158,12 @@ def save_input(self, dir_, x, t, xcross, hint, text_embeds=None, time_ids=None): ) if is_sdxl: inputs.update(dict( - text_embeds=os.path.join(dir_, x), + text_embeds=os.path.join(dir_, text_embeds), time_ids=os.path.join(dir_, time_ids), )) self.input_files = SimpleNamespace(**inputs) self.input_action = 'save' + self.minimize_bs = minimize_bs def load_input(self, dir_, x, t, xcross, hint, text_embeds=None, time_ids=None): assert (text_embeds is None and time_ids is None) or (text_embeds is not None and time_ids is not None) @@ -175,7 +176,7 @@ def load_input(self, dir_, x, t, xcross, hint, text_embeds=None, time_ids=None): ) if is_sdxl: inputs.update(dict( - text_embeds=os.path.join(dir_, x), + text_embeds=os.path.join(dir_, text_embeds), time_ids=os.path.join(dir_, time_ids), )) self.input_files = SimpleNamespace(**inputs) @@ -192,6 +193,21 @@ def do_input_action(self, x, t, xcross, hint, text_embeds=None, time_ids=None): is_sdxl = text_embeds is not None if self.input_action == 'save': + assert x.shape[0]==t.shape[0]==xcross.shape[0]==hint.shape[0] + if is_sdxl: + assert x.shape[0]==text_embeds.shape[0]==time_ids.shape[0] + + bs = x.shape[0] + if self.minimize_bs and bs > 1: + print(f'[udl] Input has batch size {bs} but reducing to 1 before saving') + x = x[0:1] + t = t[0:1] + xcross = xcross[0:1] + hint = hint[0:1] + if is_sdxl: + text_embeds = text_embeds[0:1] + time_ids = time_ids[0:1] + torch.save(x, self.input_files.x) torch.save(t, self.input_files.t) torch.save(xcross, self.input_files.xcross) From 5641711f90455078f14946e2e350b1c6a3b690dc Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 17 Jan 2024 22:01:38 +0100 Subject: [PATCH 21/75] debug: updated logs --- src/diffusers/models/controlnet_xs.py | 19 ++++---- src/diffusers/umer_debug_logger.py | 69 ++++++++++++--------------- 2 files changed, 42 insertions(+), 46 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index cfec32fb31e5..499b79cef304 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -594,20 +594,13 @@ def forward( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) - sample, timesteps, encoder_hidden_states, controlnet_cond, text_embeds, time_ids = udl.do_input_action( + sample, timesteps, encoder_hidden_states, controlnet_cond = udl.do_input_action( x=sample, t=timesteps, xcross=encoder_hidden_states, hint=controlnet_cond, - text_embeds=added_cond_kwargs.get('text_embeds', None), - time_ids=added_cond_kwargs.get('time_ids', None), ) - udl.stop_if(udl.INPUT_SAVE, 'Stopping because I only wanted to save input') - udl.log_if('sample', sample, udl.SUBBLOCK) - udl.log_if('timestep', timesteps, udl.SUBBLOCK) - udl.log_if('encoder_hidden_states', encoder_hidden_states, udl.SUBBLOCK) - udl.log_if('controlnet_cond', controlnet_cond, udl.SUBBLOCK) t_emb = self.base_time_proj(timesteps) @@ -656,10 +649,20 @@ def forward( time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(temb.dtype) + + add_embeds = udl.do_input_action_for_do_input_action(add_embeds) + aug_emb = self.base_add_embedding(add_embeds) else: raise NotImplementedError() + udl.stop_if(udl.INPUT_SAVE, 'Stopping because I only wanted to save input') + + udl.log_if('sample', sample, udl.SUBBLOCK) + udl.log_if('timestep', timesteps, udl.SUBBLOCK) + udl.log_if('encoder_hidden_states', encoder_hidden_states, udl.SUBBLOCK) + udl.log_if('controlnet_cond', controlnet_cond, udl.SUBBLOCK) + temb = temb + aug_emb if aug_emb is not None else temb # text embeddings diff --git a/src/diffusers/umer_debug_logger.py b/src/diffusers/umer_debug_logger.py index 9e0f888e3fbe..a7e86e9dd341 100644 --- a/src/diffusers/umer_debug_logger.py +++ b/src/diffusers/umer_debug_logger.py @@ -147,38 +147,28 @@ def load_log_objects_from_dir(self, log_dir): log_objects.append(SimpleNamespace(**row)) return log_objects - def save_input(self, dir_, x, t, xcross, hint, text_embeds=None, time_ids=None, minimize_bs=True): - assert (text_embeds is None and time_ids is None) or (text_embeds is not None and time_ids is not None) - is_sdxl = text_embeds is not None + def save_input(self, dir_, x, t, xcross, hint, add_embeds=None, minimize_bs=True): + is_sdxl = add_embeds is not None inputs = dict( x=os.path.join(dir_, x), t=os.path.join(dir_, t), xcross=os.path.join(dir_, xcross), hint=os.path.join(dir_,hint) ) - if is_sdxl: - inputs.update(dict( - text_embeds=os.path.join(dir_, text_embeds), - time_ids=os.path.join(dir_, time_ids), - )) + if is_sdxl: inputs['add_embeds']=os.path.join(dir_, add_embeds) self.input_files = SimpleNamespace(**inputs) self.input_action = 'save' self.minimize_bs = minimize_bs - def load_input(self, dir_, x, t, xcross, hint, text_embeds=None, time_ids=None): - assert (text_embeds is None and time_ids is None) or (text_embeds is not None and time_ids is not None) - is_sdxl = text_embeds is not None + def load_input(self, dir_, x, t, xcross, hint, add_embeds=None): + is_sdxl = add_embeds is not None inputs = dict( x=os.path.join(dir_, x), t=os.path.join(dir_, t), xcross=os.path.join(dir_, xcross), hint=os.path.join(dir_,hint) ) - if is_sdxl: - inputs.update(dict( - text_embeds=os.path.join(dir_, text_embeds), - time_ids=os.path.join(dir_, time_ids), - )) + if is_sdxl:inputs['add_embeds']=os.path.join(dir_, add_embeds), self.input_files = SimpleNamespace(**inputs) self.input_action = 'load' @@ -186,16 +176,12 @@ def dont_process_input(self): self.input_action = 'none' self.input_files = {} - def do_input_action(self, x, t, xcross, hint, text_embeds=None, time_ids=None): + def do_input_action(self, x, t, xcross, hint): assert self.input_files is not None, "self.input_files not set! Use `save_input`, `load_input` or `dont_process_input`" assert self.input_action in ['save', 'load', 'none'] - assert (text_embeds is None and time_ids is None) or (text_embeds is not None and time_ids is not None) - is_sdxl = text_embeds is not None if self.input_action == 'save': assert x.shape[0]==t.shape[0]==xcross.shape[0]==hint.shape[0] - if is_sdxl: - assert x.shape[0]==text_embeds.shape[0]==time_ids.shape[0] bs = x.shape[0] if self.minimize_bs and bs > 1: @@ -204,39 +190,46 @@ def do_input_action(self, x, t, xcross, hint, text_embeds=None, time_ids=None): t = t[0:1] xcross = xcross[0:1] hint = hint[0:1] - if is_sdxl: - text_embeds = text_embeds[0:1] - time_ids = time_ids[0:1] torch.save(x, self.input_files.x) torch.save(t, self.input_files.t) torch.save(xcross, self.input_files.xcross) torch.save(hint, self.input_files.hint) - assert x.shape[0]==t.shape[0]==xcross.shape[0]==hint.shape[0] - - if is_sdxl: - torch.save(text_embeds, self.input_files.text_embeds) - torch.save(time_ids, self.input_files.time_ids) - assert x.shape[0]==text_embeds.shape[0]==time_ids.shape[0] - print(f'[udl] Input saved (batch size = {x.shape[0]})') + elif self.input_action == 'load': x = torch.load(self.input_files.x, map_location=x.device) t = torch.load( self.input_files.t, map_location=t.device) xcross = torch.load(self.input_files.xcross, map_location=xcross.device) hint = torch.load(self.input_files.hint, map_location=hint.device) - assert x.shape[0]==t.shape[0]==xcross.shape[0]==hint.shape[0] - - if is_sdxl: - text_embeds = torch.load(self.input_files.text_embeds, map_location=text_embeds.device) - time_ids = torch.load(self.input_files.time_ids, map_location=time_ids.device) - assert x.shape[0]==text_embeds.shape[0]==time_ids.shape[0] + assert x.shape[0]==t.shape[0]==xcross.shape[0]==hint.shape[0] + print(f'[udl] Input loaded (batch size = {x.shape[0]})') else: print(f'[udl] Neither saving nor loading input (batch size = {x.shape[0]})') - return x, t, xcross, hint, text_embeds, time_ids + return x, t, xcross, hint + + def do_input_action_for_do_input_action(self, add_embeds): + assert self.input_files is not None, "self.input_files not set! Use `save_input`, `load_input` or `dont_process_input`" + assert self.input_action in ['save', 'load', 'none'] + + if self.input_action == 'save': + bs = add_embeds.shape[0] + if self.minimize_bs and bs > 1: + print(f'[udl] Input `add_embeds` has batch size {bs} but reducing to 1 before saving `add_embeds`') + add_embeds = add_embeds[0:1] + torch.save(add_embeds, self.input_files.add_embeds) + print(f'[udl] Input `add_embeds` saved (batch size = {add_embeds.shape[0]})') + + elif self.input_action == 'load': + add_embeds = torch.load(self.input_files.add_embeds, map_location=add_embeds.device) + print(f'[udl] Input loaded (batch size = {add_embeds.shape[0]})') + + else: + print(f'[udl] Neither saving nor loading input (batch size = {add_embeds.shape[0]})') + return add_embeds udl = UmerDebugLogger() From 40393380ca64425fcf3768469c88c8d290c2a74c Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 18 Jan 2024 13:49:24 +0100 Subject: [PATCH 22/75] checkim --- src/diffusers/models/attention.py | 12 +-- src/diffusers/models/controlnet_xs.py | 138 ++++++++++++++++--------- src/diffusers/models/resnet.py | 18 ++-- src/diffusers/models/transformer_2d.py | 10 +- src/diffusers/umer_debug_logger.py | 2 +- 5 files changed, 111 insertions(+), 69 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 0e27117ffa88..59093659105e 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -343,8 +343,8 @@ def forward( if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) - udl.log_if("attn1", attn_output, udl.SUBBLOCKM1) - udl.log_if("add attn1", hidden_states, udl.SUBBLOCKM1) + udl.log_if("attn: attn1", attn_output, udl.SUBBLOCKM1) + udl.log_if("attn: add attn1", hidden_states, udl.SUBBLOCKM1) # 2.5 GLIGEN Control if gligen_kwargs is not None: @@ -375,8 +375,8 @@ def forward( **cross_attention_kwargs, ) hidden_states = attn_output + hidden_states - udl.log_if("attn2", attn_output, udl.SUBBLOCKM1) - udl.log_if("add attn2", hidden_states, udl.SUBBLOCKM1) + udl.log_if("attn: attn2", attn_output, udl.SUBBLOCKM1) + udl.log_if("attn: add attn2", hidden_states, udl.SUBBLOCKM1) # 4. Feed-forward if self.use_ada_layer_norm_continuous: @@ -408,8 +408,8 @@ def forward( if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) - udl.log_if("ff", ff_output, udl.SUBBLOCKM1) - udl.log_if("add ff", hidden_states, udl.SUBBLOCKM1) + udl.log_if("attn: ff", ff_output, udl.SUBBLOCKM1) + udl.log_if("attn: add ff", hidden_states, udl.SUBBLOCKM1) return hidden_states diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 499b79cef304..dd8be0ad00cb 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -84,10 +84,48 @@ def forward(self, conditioning): class ControlNetXSAddon(ModelMixin, ConfigMixin): + r""" + A ControlNetXSAddon model + + # todo - the below comment is very outdated. update it + + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic + methods implemented for all models (such as downloading or saving). + + Most of parameters for this model are passed into the [`UNet2DConditionModel`] it creates. Check the documentation + of [`UNet2DConditionModel`] for them. + + Parameters: + conditioning_channels (`int`, defaults to 3): + Number of channels of conditioning input (e.g. an image) + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `controlnet_cond_embedding` layer. + time_embedding_input_dim (`int`, defaults to 320): + Dimension of input into time embedding. Needs to be same as in the base model. + time_embedding_dim (`int`, defaults to 1280): + Dimension of output from time embedding. Needs to be same as in the base model. + learn_embedding (`bool`, defaults to `False`): + Whether to use time embedding of the control model. If yes, the time embedding is a linear interpolation of + the time embeddings of the control and base model with interpolation parameter `time_embedding_mix**3`. + time_embedding_mix (`float`, defaults to 1.0): + Linear interpolation parameter used if `learn_embedding` is `True`. A value of 1.0 means only the + control model's time embedding will be used. A value of 0.0 means only the base model's time embedding will be used. + base_model_channel_sizes (`Dict[str, List[Tuple[int]]]`): + Channel sizes of each subblock of base model. Use `gather_subblock_sizes` on your base model to compute it. + """ @staticmethod def gather_base_subblock_sizes(blocks_sizes: List[int]): - """todo - comment""" + """ + To create a correctly sized `ControlNetXSAddon`, we need to know + the channels sizes of each base subblock. + + Parameters: + blocks_sizes (`List[int]`): + Channel sizes of each base block. + """ n_blocks = len(blocks_sizes) n_subblocks_per_block = 3 @@ -99,14 +137,17 @@ def gather_base_subblock_sizes(blocks_sizes: List[int]): for b in range(n_blocks): for i in range(n_subblocks_per_block): if b==n_blocks-1 and i==2: - # last block has now downsampler, so has only 2 subblocks instead of 3 + # Last block has no downsampler, so there are only 2 subblocks instead of 3 continue + + # The input channels are changed by the first resnet, which is in the first subblock. if i==0: - # first subblock has same input channels as in last block, - # because channels are changed by the first resnet, which is the first subblock + # Same input channels down_out.append(blocks_sizes[max(b-1,0)]) else: + # Changed input channels down_out.append(blocks_sizes[b]) + down_out.append(blocks_sizes[-1]) # up_in @@ -114,7 +155,7 @@ def gather_base_subblock_sizes(blocks_sizes: List[int]): for b in range(len(rev_blocks_sizes)): for i in range(n_subblocks_per_block): if i==0: - up_in.append(rev_blocks_sizes[max(b-1,0)]) + up_in.append(rev_blocks_sizes[max(b-1,0)]) # todo: explain max(b-1,0) else: up_in.append(rev_blocks_sizes[b]) @@ -133,7 +174,22 @@ def from_unet( num_attention_heads: Optional[List[int]] = None, learn_time_embedding: bool = False, ): - # todo - comment + r""" + Instantiate a [`ControlNetXSAddon`] from [`UNet2DConditionModel`]. + + Parameters: + base_model (`UNet2DConditionModel`): + The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it. + size_ratio (float, *optional*): + When given, block_out_channels is set to a relative fraction of the base model's block_out_channels. + Either this or `block_out_channels` must be given. + block_out_channels (`Tuple[int]`, *optional*): + Down blocks output channels in control model. Either this or `size_ratio` must be given. + num_attention_heads (`Union[int, Tuple[int]]`, *optional*): + The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. + learn_time_embedding (`bool`): + todo + """ # Check input fixed_size = block_out_channels is not None @@ -658,10 +714,10 @@ def forward( udl.stop_if(udl.INPUT_SAVE, 'Stopping because I only wanted to save input') - udl.log_if('sample', sample, udl.SUBBLOCK) - udl.log_if('timestep', timesteps, udl.SUBBLOCK) - udl.log_if('encoder_hidden_states', encoder_hidden_states, udl.SUBBLOCK) - udl.log_if('controlnet_cond', controlnet_cond, udl.SUBBLOCK) + udl.log_if("sample", sample, udl.SUBBLOCK) + udl.log_if("timestep", timesteps, udl.SUBBLOCK) + udl.log_if("encoder_hidden_states", encoder_hidden_states, udl.SUBBLOCK) + udl.log_if("controlnet_cond", controlnet_cond, udl.SUBBLOCK) temb = temb + aug_emb if aug_emb is not None else temb @@ -674,24 +730,24 @@ def forward( h_ctrl = h_base = sample hs_base, hs_ctrl = [], [] - udl.log_if('h_ctrl', h_ctrl, udl.SUBBLOCK) - udl.log_if('h_base', h_base, udl.SUBBLOCK) + udl.log_if("h_ctrl", h_ctrl, udl.SUBBLOCK) + udl.log_if("h_base", h_base, udl.SUBBLOCK) # Cross Control # 1 - conv in & down - # The base -> ctrl connections are 'delayed' by 1 subblock, because we want to 'wait' to ensure the new information from the last ctrl -> base connection is also considered + # The base -> ctrl connections are "delayed" by 1 subblock, because we want to "wait" to ensure the new information from the last ctrl -> base connection is also considered # Therefore, the connections iterate over: # ctrl -> base: conv_in | subblock 1 | ... | subblock n # base -> ctrl: | subblock 1 | ... | subblock n | mid block h_base = self.base_conv_in(h_base) - udl.log_if('base', h_base, udl.SUBBLOCK) + udl.log_if("base", h_base, udl.SUBBLOCK) h_ctrl = self.ctrl_conv_in(h_ctrl) - udl.log_if('ctrl', h_ctrl, udl.SUBBLOCK) + udl.log_if("ctrl", h_ctrl, udl.SUBBLOCK) if guided_hint is not None: h_ctrl += guided_hint h_base = h_base + self.down_zero_convs_c2b[0](h_ctrl) * conditioning_scale # add ctrl -> base - udl.log_if('add c2b', h_base, udl.SUBBLOCK) + udl.log_if("add c2b", h_base, udl.SUBBLOCK) hs_base.append(h_base) hs_ctrl.append(h_ctrl) @@ -708,47 +764,47 @@ def forward( additional_params = [] h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) # concat base -> ctrl - udl.log_if('concat b2c', h_ctrl, udl.SUBBLOCK) + udl.log_if("concat b2c", h_ctrl, udl.SUBBLOCK) h_base = b(h_base, *additional_params) # apply base subblock - udl.log_if('base', h_base, udl.SUBBLOCK) + udl.log_if("base", h_base, udl.SUBBLOCK) h_ctrl = c(h_ctrl, *additional_params) # apply ctrl subblock - udl.log_if('ctrl', h_ctrl, udl.SUBBLOCK) + udl.log_if("ctrl", h_ctrl, udl.SUBBLOCK) h_base = h_base + c2b(h_ctrl) * conditioning_scale # add ctrl -> base - udl.log_if('add c2b', h_base, udl.SUBBLOCK) + udl.log_if("add c2b", h_base, udl.SUBBLOCK) hs_base.append(h_base) hs_ctrl.append(h_ctrl) h_ctrl = torch.cat([h_ctrl, self.down_zero_convs_b2c[-1](h_base)], dim=1) # concat base -> ctrl - udl.log_if('concat b2c', h_ctrl, udl.SUBBLOCK) + udl.log_if("concat b2c", h_ctrl, udl.SUBBLOCK) # 2 - mid h_base = self.base_mid_block(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # apply base subblock - udl.log_if('base', h_base, udl.SUBBLOCK) + udl.log_if("base", h_base, udl.SUBBLOCK) h_ctrl = self.ctrl_mid_block(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # apply ctrl subblock - udl.log_if('ctrl', h_ctrl, udl.SUBBLOCK) + udl.log_if("ctrl", h_ctrl, udl.SUBBLOCK) h_base = h_base + self.mid_zero_convs_c2b(h_ctrl) * conditioning_scale # add ctrl -> base - udl.log_if('add c2b', h_base, udl.SUBBLOCK) + udl.log_if("add c2b", h_base, udl.SUBBLOCK) # 3 - up for b, c2b, skip_c, skip_b in zip( self.base_up_subblocks, self.up_zero_convs_c2b, reversed(hs_ctrl), reversed(hs_base) ): h_base = h_base + c2b(skip_c) * conditioning_scale # add info from ctrl encoder - udl.log_if('add c2b', h_base, udl.SUBBLOCK) + udl.log_if("add c2b", h_base, udl.SUBBLOCK) h_base = torch.cat([h_base, skip_b], dim=1) # concat info from base encoder+ctrl encoder h_base = b(h_base, temb, cemb, attention_mask, cross_attention_kwargs) - udl.log_if('base', h_base, udl.SUBBLOCK) + udl.log_if("base", h_base, udl.SUBBLOCK) h_base = self.base_conv_norm_out(h_base) h_base = self.base_conv_act(h_base) h_base = self.base_conv_out(h_base) - udl.log_if('conv_out', h_base, udl.SUBBLOCK) + udl.log_if("conv_out", h_base, udl.SUBBLOCK) udl.stop_if(udl.SUBBLOCK, 'It is done, my dude. Let us look at these tensors.') @@ -793,7 +849,7 @@ def __init__( self.gradient_checkpointing = False if is_empty: - # todo umer: comment + # modules will be set manually, see `CrossAttnSubBlock2D.from_modules` return self.in_channels = in_channels @@ -898,12 +954,11 @@ def __init__( self.gradient_checkpointing = False if is_empty: - # todo umer: comment + # downsampler will be set manually, see `DownSubBlock2D.from_modules` return self.in_channels = in_channels self.out_channels = out_channels - self.downsampler = Downsample2D(in_channels, use_conv=True, out_channels=out_channels, name="op") @classmethod @@ -919,28 +974,15 @@ def forward( self, hidden_states: torch.FloatTensor, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - # todo: gradient ckptin? - hidden_states = self.downsampler(hidden_states) - else: - hidden_states = self.downsampler(hidden_states) - - return hidden_states + return self.downsampler(hidden_states) class CrossAttnUpSubBlock2D(nn.Module): def __init__(self): - """todo doc - init emtpty as only from_modules will be used""" + """ + In the context of ControlNet-XS, `CrossAttnUpSubBlock2D` are only loaded from existing modules, and not created from scratch. + Therefore, `__init__` is left almost empty. + """ super().__init__() self.gradient_checkpointing = False diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 1cf9f98bb23e..a8be1c224cf7 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -189,16 +189,16 @@ def forward( ) -> torch.FloatTensor: hidden_states = input_tensor - udl.log_if('res: input', hidden_states, udl.SUBBLOCKM1) + udl.log_if("res: input", hidden_states, udl.SUBBLOCKM1) if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": hidden_states = self.norm1(hidden_states, temb) else: hidden_states = self.norm1(hidden_states) - udl.log_if('res: norm1', hidden_states, udl.SUBBLOCKM1) + udl.log_if("res: norm1", hidden_states, udl.SUBBLOCKM1) hidden_states = self.nonlinearity(hidden_states) - udl.log_if('res: nonlin', hidden_states, udl.SUBBLOCKM1) + udl.log_if("res: nonlin", hidden_states, udl.SUBBLOCKM1) if self.upsample is not None: # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 @@ -227,10 +227,10 @@ def forward( else self.downsample(hidden_states) ) - udl.log_if('res: updown', hidden_states, udl.SUBBLOCKM1) + udl.log_if("res: updown", hidden_states, udl.SUBBLOCKM1) hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states) - udl.log_if('res: conv1', hidden_states, udl.SUBBLOCKM1) + udl.log_if("res: conv1", hidden_states, udl.SUBBLOCKM1) if self.time_emb_proj is not None: if not self.skip_time_act: @@ -241,12 +241,12 @@ def forward( else self.time_emb_proj(temb)[:, :, None, None] ) - udl.log_if('res: temb', temb, udl.SUBBLOCKM1) + udl.log_if("res: temb", temb, udl.SUBBLOCKM1) if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb - udl.log_if('res: add temb', hidden_states, udl.SUBBLOCKM1) + udl.log_if("res: add temb", hidden_states, udl.SUBBLOCKM1) if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": hidden_states = self.norm2(hidden_states, temb) @@ -262,7 +262,7 @@ def forward( hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states) - udl.log_if('res: conv2', hidden_states, udl.SUBBLOCKM1) + udl.log_if("res: conv2", hidden_states, udl.SUBBLOCKM1) if self.conv_shortcut is not None: input_tensor = ( @@ -271,7 +271,7 @@ def forward( output_tensor = (input_tensor + hidden_states) / self.output_scale_factor - udl.log_if('res: out', output_tensor, udl.SUBBLOCKM1) + udl.log_if("res: out", output_tensor, udl.SUBBLOCKM1) return output_tensor diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 62313f2fda38..bf7942d993f8 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -326,7 +326,7 @@ def forward( residual = hidden_states hidden_states = self.norm(hidden_states) - udl.log_if('norm', hidden_states, udl.SUBBLOCKM1) + udl.log_if("attn: norm", hidden_states, udl.SUBBLOCKM1) if not self.use_linear_projection: hidden_states = ( @@ -345,13 +345,13 @@ def forward( else self.proj_in(hidden_states) ) - udl.log_if('proj_in', hidden_states, udl.SUBBLOCKM1) + udl.log_if("attn: proj_in", hidden_states, udl.SUBBLOCKM1) elif self.is_input_vectorized: - print('umer: wtf, this happened?') + print("umer: wtf, this happened?") hidden_states = self.latent_image_embedding(hidden_states) elif self.is_input_patches: - print('umer: wtf, why did this happen?') + print("umer: wtf, why did this happen?") height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size hidden_states = self.pos_embed(hidden_states) @@ -462,7 +462,7 @@ def custom_forward(*inputs): shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) ) - udl.log_if('proj_out', output, udl.SUBBLOCKM1) + udl.log_if("attn: proj_out", output, udl.SUBBLOCKM1) if not return_dict: return (output,) diff --git a/src/diffusers/umer_debug_logger.py b/src/diffusers/umer_debug_logger.py index a7e86e9dd341..f9cddc831d6d 100644 --- a/src/diffusers/umer_debug_logger.py +++ b/src/diffusers/umer_debug_logger.py @@ -168,7 +168,7 @@ def load_input(self, dir_, x, t, xcross, hint, add_embeds=None): xcross=os.path.join(dir_, xcross), hint=os.path.join(dir_,hint) ) - if is_sdxl:inputs['add_embeds']=os.path.join(dir_, add_embeds), + if is_sdxl:inputs['add_embeds']=os.path.join(dir_, add_embeds) self.input_files = SimpleNamespace(**inputs) self.input_action = 'load' From d7c8d43c91f27aaffe368e25cf1ee8b5ba90cc6c Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 18 Jan 2024 17:26:50 +0100 Subject: [PATCH 23/75] Readded tests --- src/diffusers/models/controlnet_xs.py | 67 ++-- .../controlnet_xs/pipeline_controlnet_xs.py | 5 +- .../pipeline_controlnet_xs_sd_xl.py | 2 +- tests/pipelines/controlnet_xs/__init__.py | 0 .../controlnet_xs/test_controlnetxs.py | 299 +++++++++++++++ .../controlnet_xs/test_controlnetxs_sdxl.py | 361 ++++++++++++++++++ 6 files changed, 709 insertions(+), 25 deletions(-) create mode 100644 tests/pipelines/controlnet_xs/__init__.py create mode 100644 tests/pipelines/controlnet_xs/test_controlnetxs.py create mode 100644 tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index dd8be0ad00cb..c7d9426abd93 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -85,20 +85,16 @@ def forward(self, conditioning): class ControlNetXSAddon(ModelMixin, ConfigMixin): r""" - A ControlNetXSAddon model - - # todo - the below comment is very outdated. update it - + A `ControlNetXSAddon` model. To use it, pass it into a `ControlNetXSModel` (together with a `UNet2DConditionModel` base model). + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). - Most of parameters for this model are passed into the [`UNet2DConditionModel`] it creates. Check the documentation - of [`UNet2DConditionModel`] for them. Parameters: conditioning_channels (`int`, defaults to 3): Number of channels of conditioning input (e.g. an image) - controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + conditioning_channel_order (`str`, defaults to `"rgb"`): The channel order of conditional image. Will convert to `rgb` if it's `bgr`. conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): The tuple of output channel for each block in the `controlnet_cond_embedding` layer. @@ -106,14 +102,34 @@ class ControlNetXSAddon(ModelMixin, ConfigMixin): Dimension of input into time embedding. Needs to be same as in the base model. time_embedding_dim (`int`, defaults to 1280): Dimension of output from time embedding. Needs to be same as in the base model. - learn_embedding (`bool`, defaults to `False`): - Whether to use time embedding of the control model. If yes, the time embedding is a linear interpolation of - the time embeddings of the control and base model with interpolation parameter `time_embedding_mix**3`. - time_embedding_mix (`float`, defaults to 1.0): - Linear interpolation parameter used if `learn_embedding` is `True`. A value of 1.0 means only the - control model's time embedding will be used. A value of 0.0 means only the base model's time embedding will be used. - base_model_channel_sizes (`Dict[str, List[Tuple[int]]]`): - Channel sizes of each subblock of base model. Use `gather_subblock_sizes` on your base model to compute it. + learn_time_embedding (`bool`, defaults to `False`): todo + Whether the time embedding should be learned or fixed. + channels_base (`Dict[str, List[Tuple[int]]]`): todo + Base channel configurations for the model's layers. + addition_embed_type (defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text_time". + addition_time_embed_dim (defaults to `None`): + Dimension for the timestep embeddings. + attention_head_dim (`list[int]`, defaults to `[4]`): + The dimension of the attention heads. + block_out_channels (`list[int]`, defaults to `[4, 8, 16, 16]`): + The tuple of output channels for each block. + cross_attention_dim (`int`, defaults to 1024): + The dimension of the cross attention features. + down_block_types (`list[str]`, defaults to `["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]`): + The tuple of downsample blocks to use. + projection_class_embeddings_input_dim (defaults to `None`): + The dimension of the `class_labels` input when + sample_size (`int`, defaults to 96): + Height and width of input/output sample. + transformer_layers_per_block (`Union[int, Tuple[int]]`, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + upcast_attention (`bool`, defaults to `True`): + todo + norm_num_groups (`int`, defaults to 32): + If `None`, normalization and activation layers is skipped in post-processing. # todo: is actually max_norm_num_groups """ @staticmethod @@ -154,9 +170,12 @@ def gather_base_subblock_sizes(blocks_sizes: List[int]): rev_blocks_sizes = list(reversed(blocks_sizes)) for b in range(len(rev_blocks_sizes)): for i in range(n_subblocks_per_block): + # The input channels are changed by the first resnet, which is in the first subblock. if i==0: - up_in.append(rev_blocks_sizes[max(b-1,0)]) # todo: explain max(b-1,0) + # Same input channels + up_in.append(rev_blocks_sizes[max(b-1,0)]) else: + # Changed input channels up_in.append(rev_blocks_sizes[b]) return { @@ -173,6 +192,7 @@ def from_unet( block_out_channels: Optional[List[int]] = None, num_attention_heads: Optional[List[int]] = None, learn_time_embedding: bool = False, + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), ): r""" Instantiate a [`ControlNetXSAddon`] from [`UNet2DConditionModel`]. @@ -180,15 +200,17 @@ def from_unet( Parameters: base_model (`UNet2DConditionModel`): The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it. - size_ratio (float, *optional*): + size_ratio (float, *optional*, defaults to `None`): When given, block_out_channels is set to a relative fraction of the base model's block_out_channels. Either this or `block_out_channels` must be given. - block_out_channels (`Tuple[int]`, *optional*): + block_out_channels (`Tuple[int]`, *optional*, defaults to `None`): Down blocks output channels in control model. Either this or `size_ratio` must be given. - num_attention_heads (`Union[int, Tuple[int]]`, *optional*): + num_attention_heads (`Union[int, Tuple[int]]`, *optional*, defaults to `None`): The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. - learn_time_embedding (`bool`): - todo + learn_time_embedding (`bool`, defaults to `False`): + Whether the `ControlNetXSAddon` should learn a time embedding. + conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `controlnet_cond_embedding` layer. """ # Check input @@ -221,13 +243,14 @@ def from_unet( transformer_layers_per_block=base_model.config.transformer_layers_per_block, upcast_attention=base_model.config.upcast_attention, norm_num_groups=norm_num_groups, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, ) @register_to_config def __init__( self, - conditioning_channel_order: str = "rgb", conditioning_channels: int = 3, + conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), time_embedding_input_dim: int = 320, time_embedding_dim: int = 1280, diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 7a34ef526002..f51ef5df721f 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -91,6 +91,7 @@ class StableDiffusionControlNetXSPipeline( DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" + # todo Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods @@ -109,7 +110,7 @@ class StableDiffusionControlNetXSPipeline( A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. - controlnet ([`ControlNetXSModel`]): + controlnet_addon ([`ControlNetXSAddon`]): Provides additional conditioning to the `unet` during the denoising process. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of @@ -163,7 +164,7 @@ def __init__( ) = controlnet_addon._check_if_vae_compatible(vae) if not vae_compatible: raise ValueError( - f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`." + f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXSAddon model ({cnxs_condition_downsample_factor}) need to be equal. Consider building the ControlNetXSAddon model with different `conditioning_embedding_out_channels`." ) self.register_modules( diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index 5caafc4ee48b..631b700a8a89 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -159,7 +159,7 @@ def __init__( ) = controlnet_addon._check_if_vae_compatible(vae) if not vae_compatible: raise ValueError( - f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`." + f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXSAddon model ({cnxs_condition_downsample_factor}) need to be equal. Consider building the ControlNetXSAddon model with different `conditioning_embedding_out_channels`." ) self.register_modules( diff --git a/tests/pipelines/controlnet_xs/__init__.py b/tests/pipelines/controlnet_xs/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py new file mode 100644 index 000000000000..f40e75dcabe3 --- /dev/null +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -0,0 +1,299 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import traceback +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + ControlNetXSAddon, + DDIMScheduler, + LCMScheduler, + StableDiffusionControlNetXSPipeline, + UNet2DConditionModel, +) +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import ( + enable_full_determinism, + load_image, + load_numpy, + require_python39_or_higher, + require_torch_2, + require_torch_gpu, + run_test_in_subprocess, + slow, + torch_device, +) +from diffusers.utils.torch_utils import randn_tensor + +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_PARAMS, +) +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, +) + + +enable_full_determinism() + + +# Will be run via run_test_in_subprocess +def _test_stable_diffusion_compile(in_queue, out_queue, timeout): + error = None + try: + _ = in_queue.get(timeout=timeout) + + controlnet_addon = ControlNetXSAddon.from_pretrained("todo umer") + + pipe = StableDiffusionControlNetXSPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", safety_checker=None, controlnet_addon=controlnet_addon + ) + pipe.to("cuda") + pipe.set_progress_bar_config(disable=None) + + pipe.controlnet.to(memory_format=torch.channels_last) + pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "bird" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ).resize((512, 512)) + + output = pipe(prompt, image, num_inference_steps=10, generator=generator, output_type="np") + image = output.images[0] + + assert image.shape == (512, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy" + ) + expected_image = np.resize(expected_image, (512, 512, 3)) + + assert np.abs(expected_image - image).max() < 1.0 + + except Exception: + error = f"{traceback.format_exc()}" + + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() + + +class ControlNetXSPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): + pipeline_class = StableDiffusionControlNetXSPipeline + params = TEXT_TO_IMAGE_PARAMS + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + def get_dummy_components(self, time_cond_proj_dim=None): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(4, 8), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + norm_num_groups=1, + time_cond_proj_dim=time_cond_proj_dim, + ) + torch.manual_seed(0) + controlnet_addon = ControlNetXSAddon.from_unet( + base_model=unet, + size_ratio=0.5, + num_attention_heads=2, + learn_time_embedding=True, + conditioning_embedding_out_channels=(16,32), + ) + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[4, 8], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + norm_num_groups=2, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "controlnet_addon": controlnet_addon, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + image = randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + "image": image, + } + + return inputs + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + def test_controlnet_lcm(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components(time_cond_proj_dim=256) + sd_pipe = StableDiffusionControlNetXSPipeline(**components) + sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs) + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [0.52700454, 0.3930534, 0.25509018, 0.7132304, 0.53696585, 0.46568912, 0.7095368, 0.7059624, 0.4744786] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + +@slow +@require_torch_gpu +class ControlNetXSPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_canny(self): + controlnet_addon = ControlNetXSAddon.from_pretrained("UmerHA/Testing-ConrolNetXS-SD2.1-canny") + + pipe = StableDiffusionControlNetXSPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", safety_checker=None, controlnet_addon=controlnet_addon + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "bird" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ) + + output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3) + + image = output.images[0] + + assert image.shape == (768, 512, 3) + + original_image = image[-3:, -3:, -1].flatten() + expected_image = np.array([0.1274, 0.1401, 0.147, 0.1185, 0.1555, 0.1492, 0.1565, 0.1474, 0.1701]) + assert np.allclose(original_image, expected_image, atol=1e-04) + + def test_depth(self): + controlnet_addon = ControlNetXSAddon.from_pretrained("todo umer") + + pipe = StableDiffusionControlNetXSPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", safety_checker=None, controlnet_addon=controlnet_addon + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "Stormtrooper's lecture" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png" + ) + + output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3) + + image = output.images[0] + + assert image.shape == (512, 512, 3) + + original_image = image[-3:, -3:, -1].flatten() + expected_image = np.array([0.1098, 0.1025, 0.1211, 0.1129, 0.1165, 0.1262, 0.1185, 0.1261, 0.1703]) + assert np.allclose(original_image, expected_image, atol=1e-04) + + @require_python39_or_higher + @require_torch_2 + def test_stable_diffusion_compile(self): + run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None) diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py new file mode 100644 index 000000000000..253b4ae4b0fe --- /dev/null +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -0,0 +1,361 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + ControlNetXSAddon, + EulerDiscreteScheduler, + StableDiffusionXLControlNetXSPipeline, + UNet2DConditionModel, +) +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device +from diffusers.utils.torch_utils import randn_tensor + +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_PARAMS, +) +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, +) + + +enable_full_determinism() + + +class StableDiffusionXLControlNetXSPipelineFastTests( + PipelineLatentTesterMixin, + PipelineKarrasSchedulerTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, + unittest.TestCase, +): + pipeline_class = StableDiffusionXLControlNetXSPipeline + params = TEXT_TO_IMAGE_PARAMS + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + # SD2-specific config below + attention_head_dim=(2, 4), + use_linear_projection=True, + addition_embed_type="text_time", + addition_time_embed_dim=8, + transformer_layers_per_block=(1, 2), + projection_class_embeddings_input_dim=80, # 6 * 8 + 32 + cross_attention_dim=64, + ) + torch.manual_seed(0) + controlnet_addon = ControlNetXSAddon.from_unet( + base_model=unet, + size_ratio=0.5, + learn_time_embedding=True, + conditioning_embedding_out_channels=(16,32), + ) + torch.manual_seed(0) + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + steps_offset=1, + beta_schedule="scaled_linear", + timestep_spacing="leading", + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=32, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "controlnet_addon": controlnet_addon, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + } + return components + + # copied from test_controlnet_sdxl.py + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + image = randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "np", + "image": image, + } + + return inputs + + # copied from test_controlnet_sdxl.py + def test_attention_slicing_forward_pass(self): + return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) + + # copied from test_controlnet_sdxl.py + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + # copied from test_controlnet_sdxl.py + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + # copied from test_controlnet_sdxl.py + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + + # copied from test_controlnet_sdxl.py + @require_torch_gpu + def test_stable_diffusion_xl_offloads(self): + pipes = [] + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components).to(torch_device) + pipes.append(sd_pipe) + + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components) + sd_pipe.enable_model_cpu_offload() + pipes.append(sd_pipe) + + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components) + sd_pipe.enable_sequential_cpu_offload() + pipes.append(sd_pipe) + + image_slices = [] + for pipe in pipes: + pipe.unet.set_default_attn_processor() + + inputs = self.get_dummy_inputs(torch_device) + image = pipe(**inputs).images + + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3 + + # copied from test_controlnet_sdxl.py + def test_stable_diffusion_xl_multi_prompts(self): + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components).to(torch_device) + + # forward with single prompt + inputs = self.get_dummy_inputs(torch_device) + output = sd_pipe(**inputs) + image_slice_1 = output.images[0, -3:, -3:, -1] + + # forward with same prompt duplicated + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = inputs["prompt"] + output = sd_pipe(**inputs) + image_slice_2 = output.images[0, -3:, -3:, -1] + + # ensure the results are equal + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + + # forward with different prompt + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "different prompt" + output = sd_pipe(**inputs) + image_slice_3 = output.images[0, -3:, -3:, -1] + + # ensure the results are not equal + assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 + + # manually set a negative_prompt + inputs = self.get_dummy_inputs(torch_device) + inputs["negative_prompt"] = "negative prompt" + output = sd_pipe(**inputs) + image_slice_1 = output.images[0, -3:, -3:, -1] + + # forward with same negative_prompt duplicated + inputs = self.get_dummy_inputs(torch_device) + inputs["negative_prompt"] = "negative prompt" + inputs["negative_prompt_2"] = inputs["negative_prompt"] + output = sd_pipe(**inputs) + image_slice_2 = output.images[0, -3:, -3:, -1] + + # ensure the results are equal + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + + # forward with different negative_prompt + inputs = self.get_dummy_inputs(torch_device) + inputs["negative_prompt"] = "negative prompt" + inputs["negative_prompt_2"] = "different negative prompt" + output = sd_pipe(**inputs) + image_slice_3 = output.images[0, -3:, -3:, -1] + + # ensure the results are not equal + assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 + + # copied from test_stable_diffusion_xl.py + def test_stable_diffusion_xl_prompt_embeds(self): + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + # forward without prompt embeds + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt"] = 2 * [inputs["prompt"]] + inputs["num_images_per_prompt"] = 2 + + output = sd_pipe(**inputs) + image_slice_1 = output.images[0, -3:, -3:, -1] + + # forward with prompt embeds + inputs = self.get_dummy_inputs(torch_device) + prompt = 2 * [inputs.pop("prompt")] + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = sd_pipe.encode_prompt(prompt) + + output = sd_pipe( + **inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + image_slice_2 = output.images[0, -3:, -3:, -1] + + # make sure that it's equal + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1.1e-4 + + +@slow +@require_torch_gpu +class ControlNetSDXLPipelineXSSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_canny(self): + controlnet_addon = ControlNetXSAddon.from_pretrained("UmerHA/Testing-ConrolNetXS-SDXL-canny") + + pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", controlnet_addon=controlnet_addon + ) + pipe.enable_sequential_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "bird" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ) + + images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images + + assert images[0].shape == (768, 512, 3) + + original_image = images[0, -3:, -3:, -1].flatten() + expected_image = np.array([0.4359, 0.4335, 0.4609, 0.4515, 0.4669, 0.4494, 0.452, 0.4493, 0.4382]) + assert np.allclose(original_image, expected_image, atol=1e-04) + + def test_depth(self): + controlnet_addon = ControlNetXSAddon.from_pretrained("todo umer") + + pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", controlnet_addon=controlnet_addon + ) + pipe.enable_sequential_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "Stormtrooper's lecture" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png" + ) + + images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images + + assert images[0].shape == (512, 512, 3) + + original_image = images[0, -3:, -3:, -1].flatten() + expected_image = np.array([0.4411, 0.3617, 0.2654, 0.266, 0.3449, 0.3898, 0.3745, 0.353, 0.326]) + assert np.allclose(original_image, expected_image, atol=1e-04) From 1d8d5b82933e7203e28f64b95d7246733f3448b2 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 18 Jan 2024 17:31:35 +0100 Subject: [PATCH 24/75] Removed debug logs --- src/diffusers/models/attention.py | 8 - src/diffusers/models/controlnet_xs.py | 44 ----- src/diffusers/models/resnet.py | 16 -- src/diffusers/models/transformer_2d.py | 8 - src/diffusers/umer_debug_logger.py | 235 ------------------------- 5 files changed, 311 deletions(-) delete mode 100644 src/diffusers/umer_debug_logger.py diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 59093659105e..0fbe9beb1ddd 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -17,7 +17,6 @@ import torch.nn.functional as F from torch import nn -from ..umer_debug_logger import udl from ..utils import USE_PEFT_BACKEND from ..utils.torch_utils import maybe_allow_in_graph from .activations import GEGLU, GELU, ApproximateGELU @@ -343,8 +342,6 @@ def forward( if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) - udl.log_if("attn: attn1", attn_output, udl.SUBBLOCKM1) - udl.log_if("attn: add attn1", hidden_states, udl.SUBBLOCKM1) # 2.5 GLIGEN Control if gligen_kwargs is not None: @@ -375,8 +372,6 @@ def forward( **cross_attention_kwargs, ) hidden_states = attn_output + hidden_states - udl.log_if("attn: attn2", attn_output, udl.SUBBLOCKM1) - udl.log_if("attn: add attn2", hidden_states, udl.SUBBLOCKM1) # 4. Feed-forward if self.use_ada_layer_norm_continuous: @@ -408,9 +403,6 @@ def forward( if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) - udl.log_if("attn: ff", ff_output, udl.SUBBLOCKM1) - udl.log_if("attn: add ff", hidden_states, udl.SUBBLOCKM1) - return hidden_states diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index c7d9426abd93..9e14b05bf11c 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -7,7 +7,6 @@ from torch import nn from torch.nn import functional as F -from ..umer_debug_logger import udl from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, is_torch_version, logging from .autoencoders import AutoencoderKL @@ -673,14 +672,6 @@ def forward( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) - sample, timesteps, encoder_hidden_states, controlnet_cond = udl.do_input_action( - x=sample, - t=timesteps, - xcross=encoder_hidden_states, - hint=controlnet_cond, - ) - - t_emb = self.base_time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors @@ -728,20 +719,10 @@ def forward( time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(temb.dtype) - - add_embeds = udl.do_input_action_for_do_input_action(add_embeds) - aug_emb = self.base_add_embedding(add_embeds) else: raise NotImplementedError() - udl.stop_if(udl.INPUT_SAVE, 'Stopping because I only wanted to save input') - - udl.log_if("sample", sample, udl.SUBBLOCK) - udl.log_if("timestep", timesteps, udl.SUBBLOCK) - udl.log_if("encoder_hidden_states", encoder_hidden_states, udl.SUBBLOCK) - udl.log_if("controlnet_cond", controlnet_cond, udl.SUBBLOCK) - temb = temb + aug_emb if aug_emb is not None else temb # text embeddings @@ -753,9 +734,6 @@ def forward( h_ctrl = h_base = sample hs_base, hs_ctrl = [], [] - udl.log_if("h_ctrl", h_ctrl, udl.SUBBLOCK) - udl.log_if("h_base", h_base, udl.SUBBLOCK) - # Cross Control # 1 - conv in & down # The base -> ctrl connections are "delayed" by 1 subblock, because we want to "wait" to ensure the new information from the last ctrl -> base connection is also considered @@ -764,13 +742,10 @@ def forward( # base -> ctrl: | subblock 1 | ... | subblock n | mid block h_base = self.base_conv_in(h_base) - udl.log_if("base", h_base, udl.SUBBLOCK) h_ctrl = self.ctrl_conv_in(h_ctrl) - udl.log_if("ctrl", h_ctrl, udl.SUBBLOCK) if guided_hint is not None: h_ctrl += guided_hint h_base = h_base + self.down_zero_convs_c2b[0](h_ctrl) * conditioning_scale # add ctrl -> base - udl.log_if("add c2b", h_base, udl.SUBBLOCK) hs_base.append(h_base) hs_ctrl.append(h_ctrl) @@ -787,49 +762,30 @@ def forward( additional_params = [] h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) # concat base -> ctrl - udl.log_if("concat b2c", h_ctrl, udl.SUBBLOCK) - h_base = b(h_base, *additional_params) # apply base subblock - udl.log_if("base", h_base, udl.SUBBLOCK) - h_ctrl = c(h_ctrl, *additional_params) # apply ctrl subblock - udl.log_if("ctrl", h_ctrl, udl.SUBBLOCK) - h_base = h_base + c2b(h_ctrl) * conditioning_scale # add ctrl -> base - udl.log_if("add c2b", h_base, udl.SUBBLOCK) hs_base.append(h_base) hs_ctrl.append(h_ctrl) h_ctrl = torch.cat([h_ctrl, self.down_zero_convs_b2c[-1](h_base)], dim=1) # concat base -> ctrl - udl.log_if("concat b2c", h_ctrl, udl.SUBBLOCK) # 2 - mid h_base = self.base_mid_block(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # apply base subblock - udl.log_if("base", h_base, udl.SUBBLOCK) - h_ctrl = self.ctrl_mid_block(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # apply ctrl subblock - udl.log_if("ctrl", h_ctrl, udl.SUBBLOCK) - h_base = h_base + self.mid_zero_convs_c2b(h_ctrl) * conditioning_scale # add ctrl -> base - udl.log_if("add c2b", h_base, udl.SUBBLOCK) # 3 - up for b, c2b, skip_c, skip_b in zip( self.base_up_subblocks, self.up_zero_convs_c2b, reversed(hs_ctrl), reversed(hs_base) ): h_base = h_base + c2b(skip_c) * conditioning_scale # add info from ctrl encoder - udl.log_if("add c2b", h_base, udl.SUBBLOCK) - h_base = torch.cat([h_base, skip_b], dim=1) # concat info from base encoder+ctrl encoder h_base = b(h_base, temb, cemb, attention_mask, cross_attention_kwargs) - udl.log_if("base", h_base, udl.SUBBLOCK) h_base = self.base_conv_norm_out(h_base) h_base = self.base_conv_act(h_base) h_base = self.base_conv_out(h_base) - udl.log_if("conv_out", h_base, udl.SUBBLOCK) - - udl.stop_if(udl.SUBBLOCK, 'It is done, my dude. Let us look at these tensors.') if not return_dict: return h_base diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index a8be1c224cf7..bbfb71ca3fbf 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -20,7 +20,6 @@ import torch.nn as nn import torch.nn.functional as F -from ..umer_debug_logger import udl from ..utils import USE_PEFT_BACKEND from .activations import get_activation from .attention_processor import SpatialNorm @@ -189,16 +188,12 @@ def forward( ) -> torch.FloatTensor: hidden_states = input_tensor - udl.log_if("res: input", hidden_states, udl.SUBBLOCKM1) - if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": hidden_states = self.norm1(hidden_states, temb) else: hidden_states = self.norm1(hidden_states) - udl.log_if("res: norm1", hidden_states, udl.SUBBLOCKM1) hidden_states = self.nonlinearity(hidden_states) - udl.log_if("res: nonlin", hidden_states, udl.SUBBLOCKM1) if self.upsample is not None: # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 @@ -227,10 +222,7 @@ def forward( else self.downsample(hidden_states) ) - udl.log_if("res: updown", hidden_states, udl.SUBBLOCKM1) - hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states) - udl.log_if("res: conv1", hidden_states, udl.SUBBLOCKM1) if self.time_emb_proj is not None: if not self.skip_time_act: @@ -241,13 +233,9 @@ def forward( else self.time_emb_proj(temb)[:, :, None, None] ) - udl.log_if("res: temb", temb, udl.SUBBLOCKM1) - if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb - udl.log_if("res: add temb", hidden_states, udl.SUBBLOCKM1) - if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": hidden_states = self.norm2(hidden_states, temb) else: @@ -262,8 +250,6 @@ def forward( hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states) - udl.log_if("res: conv2", hidden_states, udl.SUBBLOCKM1) - if self.conv_shortcut is not None: input_tensor = ( self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor) @@ -271,8 +257,6 @@ def forward( output_tensor = (input_tensor + hidden_states) / self.output_scale_factor - udl.log_if("res: out", output_tensor, udl.SUBBLOCKM1) - return output_tensor diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index bf7942d993f8..f97c3dbebe2c 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -18,7 +18,6 @@ import torch.nn.functional as F from torch import nn -from ..umer_debug_logger import udl from ..configuration_utils import ConfigMixin, register_to_config from ..models.embeddings import ImagePositionalEmbeddings from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version @@ -326,7 +325,6 @@ def forward( residual = hidden_states hidden_states = self.norm(hidden_states) - udl.log_if("attn: norm", hidden_states, udl.SUBBLOCKM1) if not self.use_linear_projection: hidden_states = ( @@ -345,13 +343,9 @@ def forward( else self.proj_in(hidden_states) ) - udl.log_if("attn: proj_in", hidden_states, udl.SUBBLOCKM1) - elif self.is_input_vectorized: - print("umer: wtf, this happened?") hidden_states = self.latent_image_embedding(hidden_states) elif self.is_input_patches: - print("umer: wtf, why did this happen?") height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size hidden_states = self.pos_embed(hidden_states) @@ -462,8 +456,6 @@ def custom_forward(*inputs): shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) ) - udl.log_if("attn: proj_out", output, udl.SUBBLOCKM1) - if not return_dict: return (output,) diff --git a/src/diffusers/umer_debug_logger.py b/src/diffusers/umer_debug_logger.py deleted file mode 100644 index f9cddc831d6d..000000000000 --- a/src/diffusers/umer_debug_logger.py +++ /dev/null @@ -1,235 +0,0 @@ -# Logger to help me (UmerHA) debug controlnet-xs - -import csv -import inspect -import os -import shutil -from datetime import datetime -from types import SimpleNamespace - -import torch - - -class UmerDebugLogger: - _FILE = "udl.csv" - - INPUT_SAVE = 'input_save' - BLOCK = 'block' - SUBBLOCK = 'subblock' - SUBBLOCKM1 = 'subblock-minus-1' - allowed_conditions = [INPUT_SAVE, BLOCK, SUBBLOCK, SUBBLOCKM1] - - input_files = None - - def __init__(self, log_dir="logs", condition=None): - self.log_dir, self.condition, self.tensor_counter = log_dir, condition, 0 - os.makedirs(log_dir, exist_ok=True) - self.fields = ["timestamp", "cls", "fn", "shape", "msg", "condition", "tensor_file"] - self.create_file() - self.warned_of_no_condition = False - print( - "Info: `UmerDebugLogger` created. This is a logging class that will be deleted when the PR to integrate ControlNet-XS is done." - ) - - @property - def full_file_path(self): - return os.path.join(self.log_dir, self._FILE) - - def create_file(self): - file = self.full_file_path - if not os.path.isfile(file): - with open(file, "w", newline="") as f: - writer = csv.DictWriter(f, fieldnames=self.fields) - writer.writeheader() - - def set_dir(self, log_dir, clear=False): - self.log_dir = log_dir - if clear: - self.clear_logs() - self.create_file() - - def clear_logs(self): - shutil.rmtree(self.log_dir, ignore_errors=True) - os.makedirs(self.log_dir, exist_ok=True) - self.create_file() - - def set_condition(self, condition): - if not isinstance(condition, list): condition = [condition] - self.condition = condition - - def check_condition(self, condition): - if not condition in self.allowed_conditions: raise ValueError(f'Unknown condition: {condition}') - return condition in self.condition - - def log_if(self, msg, t, condition, *, print_=False): - self.maybe_warn_of_no_condition() - - if not self.check_condition(condition): - return - - # Use inspect to get the current frame and then go back one level to find caller - frame = inspect.currentframe() - caller_frame = frame.f_back - caller_info = inspect.getframeinfo(caller_frame) - - # Extract class and function name from the caller - cls_name = ( - caller_frame.f_locals.get("self", None).__class__.__name__ if "self" in caller_frame.f_locals else None - ) - function_name = caller_info.function - - if not hasattr(t, "shape"): - t = torch.tensor(t) - t = t.cpu().detach() - - # Save tensor to a file - tensor_filename = f"tensor_{self.tensor_counter}.pt" - torch.save(t, os.path.join(self.log_dir, tensor_filename)) - self.tensor_counter += 1 - - # Log information to CSV - log_info = { - "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "cls": cls_name, - "fn": function_name, - "shape": str(list(t.shape)), - "msg": msg, - "condition": condition, - "tensor_file": tensor_filename, - } - - with open(self.full_file_path, "a", newline="") as f: - writer = csv.DictWriter(f, fieldnames=self.fields) - writer.writerow(log_info) - - if print_: - print(f"{msg}\t{t.flatten()[:10]}") - - def print_if(self, msg, conditions, end="\n"): - self.maybe_warn_of_no_condition() - if not isinstance(conditions, (tuple, list)): - conditions = [conditions] - if any(self.condition == c for c in conditions): - print(msg, end=end) - - def stop_if(self, condition, funny_msg): - if self.check_condition(condition): - current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - raise SystemExit(f"{funny_msg} - {current_time}") - - def maybe_warn_of_no_condition(self): - if self.condition is None and not self.warned_of_no_condition: - print("Info: No condition set for UmerDebugLogger") - self.warned_of_no_condition = True - - def get_log_objects(self): - log_objects = [] - file = self.full_file_path - with open(file, newline="") as f: - reader = csv.DictReader(f) - for row in reader: - row["tensor"] = torch.load(os.path.join(self.log_dir, row["tensor_file"])) - row["head"] = row["tensor"].flatten()[:10] - del row["tensor_file"] - log_objects.append(SimpleNamespace(**row)) - return log_objects - - @classmethod - def load_log_objects_from_dir(self, log_dir): - file = os.path.join(log_dir, self._FILE) - log_objects = [] - with open(file, newline="") as f: - reader = csv.DictReader(f) - for row in reader: - row["t"] = torch.load(os.path.join(log_dir, row["tensor_file"])) - row["head"] = row["t"].flatten()[:10] - del row["tensor_file"] - log_objects.append(SimpleNamespace(**row)) - return log_objects - - def save_input(self, dir_, x, t, xcross, hint, add_embeds=None, minimize_bs=True): - is_sdxl = add_embeds is not None - inputs = dict( - x=os.path.join(dir_, x), - t=os.path.join(dir_, t), - xcross=os.path.join(dir_, xcross), - hint=os.path.join(dir_,hint) - ) - if is_sdxl: inputs['add_embeds']=os.path.join(dir_, add_embeds) - self.input_files = SimpleNamespace(**inputs) - self.input_action = 'save' - self.minimize_bs = minimize_bs - - def load_input(self, dir_, x, t, xcross, hint, add_embeds=None): - is_sdxl = add_embeds is not None - inputs = dict( - x=os.path.join(dir_, x), - t=os.path.join(dir_, t), - xcross=os.path.join(dir_, xcross), - hint=os.path.join(dir_,hint) - ) - if is_sdxl:inputs['add_embeds']=os.path.join(dir_, add_embeds) - self.input_files = SimpleNamespace(**inputs) - self.input_action = 'load' - - def dont_process_input(self): - self.input_action = 'none' - self.input_files = {} - - def do_input_action(self, x, t, xcross, hint): - assert self.input_files is not None, "self.input_files not set! Use `save_input`, `load_input` or `dont_process_input`" - assert self.input_action in ['save', 'load', 'none'] - - if self.input_action == 'save': - assert x.shape[0]==t.shape[0]==xcross.shape[0]==hint.shape[0] - - bs = x.shape[0] - if self.minimize_bs and bs > 1: - print(f'[udl] Input has batch size {bs} but reducing to 1 before saving') - x = x[0:1] - t = t[0:1] - xcross = xcross[0:1] - hint = hint[0:1] - - torch.save(x, self.input_files.x) - torch.save(t, self.input_files.t) - torch.save(xcross, self.input_files.xcross) - torch.save(hint, self.input_files.hint) - - print(f'[udl] Input saved (batch size = {x.shape[0]})') - - elif self.input_action == 'load': - x = torch.load(self.input_files.x, map_location=x.device) - t = torch.load( self.input_files.t, map_location=t.device) - xcross = torch.load(self.input_files.xcross, map_location=xcross.device) - hint = torch.load(self.input_files.hint, map_location=hint.device) - - assert x.shape[0]==t.shape[0]==xcross.shape[0]==hint.shape[0] - - print(f'[udl] Input loaded (batch size = {x.shape[0]})') - else: - print(f'[udl] Neither saving nor loading input (batch size = {x.shape[0]})') - return x, t, xcross, hint - - def do_input_action_for_do_input_action(self, add_embeds): - assert self.input_files is not None, "self.input_files not set! Use `save_input`, `load_input` or `dont_process_input`" - assert self.input_action in ['save', 'load', 'none'] - - if self.input_action == 'save': - bs = add_embeds.shape[0] - if self.minimize_bs and bs > 1: - print(f'[udl] Input `add_embeds` has batch size {bs} but reducing to 1 before saving `add_embeds`') - add_embeds = add_embeds[0:1] - torch.save(add_embeds, self.input_files.add_embeds) - print(f'[udl] Input `add_embeds` saved (batch size = {add_embeds.shape[0]})') - - elif self.input_action == 'load': - add_embeds = torch.load(self.input_files.add_embeds, map_location=add_embeds.device) - print(f'[udl] Input loaded (batch size = {add_embeds.shape[0]})') - - else: - print(f'[udl] Neither saving nor loading input (batch size = {add_embeds.shape[0]})') - - return add_embeds - -udl = UmerDebugLogger() From 55b73500a0a7e934ae563db5cb3cb61576522cda Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Fri, 19 Jan 2024 12:57:35 +0100 Subject: [PATCH 25/75] Fixed Slow Tests --- src/diffusers/models/attention.py | 1 - src/diffusers/models/controlnet_xs.py | 64 ++++++++----------- src/diffusers/models/transformer_2d.py | 3 - .../controlnet_xs/test_controlnetxs.py | 4 +- .../controlnet_xs/test_controlnetxs_sdxl.py | 8 ++- 5 files changed, 34 insertions(+), 46 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 0fbe9beb1ddd..fc4564c3a6ff 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -342,7 +342,6 @@ def forward( if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) - # 2.5 GLIGEN Control if gligen_kwargs is not None: hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 9e14b05bf11c..12cd0263b02b 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -85,7 +85,7 @@ def forward(self, conditioning): class ControlNetXSAddon(ModelMixin, ConfigMixin): r""" A `ControlNetXSAddon` model. To use it, pass it into a `ControlNetXSModel` (together with a `UNet2DConditionModel` base model). - + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). @@ -105,11 +105,6 @@ class ControlNetXSAddon(ModelMixin, ConfigMixin): Whether the time embedding should be learned or fixed. channels_base (`Dict[str, List[Tuple[int]]]`): todo Base channel configurations for the model's layers. - addition_embed_type (defaults to `None`): - Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or - "text_time". - addition_time_embed_dim (defaults to `None`): - Dimension for the timestep embeddings. attention_head_dim (`list[int]`, defaults to `[4]`): The dimension of the attention heads. block_out_channels (`list[int]`, defaults to `[4, 8, 16, 16]`): @@ -118,8 +113,6 @@ class ControlNetXSAddon(ModelMixin, ConfigMixin): The dimension of the cross attention features. down_block_types (`list[str]`, defaults to `["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]`): The tuple of downsample blocks to use. - projection_class_embeddings_input_dim (defaults to `None`): - The dimension of the `class_labels` input when sample_size (`int`, defaults to 96): Height and width of input/output sample. transformer_layers_per_block (`Union[int, Tuple[int]]`, defaults to 1): @@ -151,14 +144,14 @@ def gather_base_subblock_sizes(blocks_sizes: List[int]): # down_out for b in range(n_blocks): for i in range(n_subblocks_per_block): - if b==n_blocks-1 and i==2: + if b == n_blocks - 1 and i == 2: # Last block has no downsampler, so there are only 2 subblocks instead of 3 continue # The input channels are changed by the first resnet, which is in the first subblock. - if i==0: + if i == 0: # Same input channels - down_out.append(blocks_sizes[max(b-1,0)]) + down_out.append(blocks_sizes[max(b - 1, 0)]) else: # Changed input channels down_out.append(blocks_sizes[b]) @@ -170,9 +163,9 @@ def gather_base_subblock_sizes(blocks_sizes: List[int]): for b in range(len(rev_blocks_sizes)): for i in range(n_subblocks_per_block): # The input channels are changed by the first resnet, which is in the first subblock. - if i==0: + if i == 0: # Same input channels - up_in.append(rev_blocks_sizes[max(b-1,0)]) + up_in.append(rev_blocks_sizes[max(b - 1, 0)]) else: # Changed input channels up_in.append(rev_blocks_sizes[b]) @@ -224,25 +217,28 @@ def from_unet( block_out_channels = [int(b * size_ratio) for b in base_model.config.block_out_channels] if num_attention_heads is None: - num_attention_heads = base_model.config.num_attention_heads + # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. + num_attention_heads = base_model.config.attention_head_dim norm_num_groups = math.gcd(*block_out_channels) + time_embedding_input_dim = base_model.time_embedding.linear_1.in_features + time_embedding_dim = base_model.time_embedding.linear_1.out_features + return ControlNetXSAddon( learn_time_embedding=learn_time_embedding, channels_base=channels_base, - addition_embed_type=base_model.config.addition_embed_type, - addition_time_embed_dim=base_model.config.addition_time_embed_dim, attention_head_dim=num_attention_heads, block_out_channels=block_out_channels, cross_attention_dim=base_model.config.cross_attention_dim, down_block_types=base_model.config.down_block_types, - projection_class_embeddings_input_dim=base_model.config.projection_class_embeddings_input_dim, sample_size=base_model.config.sample_size, transformer_layers_per_block=base_model.config.transformer_layers_per_block, upcast_attention=base_model.config.upcast_attention, norm_num_groups=norm_num_groups, conditioning_embedding_out_channels=conditioning_embedding_out_channels, + time_embedding_input_dim=time_embedding_input_dim, + time_embedding_dim=time_embedding_dim, ) @register_to_config @@ -251,21 +247,18 @@ def __init__( conditioning_channels: int = 3, conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), - time_embedding_input_dim: int = 320, - time_embedding_dim: int = 1280, + time_embedding_input_dim: Optional[int] = 320, + time_embedding_dim: Optional[int] = 1280, learn_time_embedding: bool = False, channels_base: Dict[str, List[Tuple[int]]] = { "down - out": [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280], "mid - out": 1280, "up - in": [1280, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320], }, - addition_embed_type=None, - addition_time_embed_dim=None, attention_head_dim=[4], block_out_channels=[4, 8, 16, 16], cross_attention_dim=1024, down_block_types=["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], - projection_class_embeddings_input_dim=None, sample_size=96, transformer_layers_per_block: Union[int, Tuple[int]] = 1, upcast_attention=True, @@ -292,19 +285,12 @@ def __init__( # time if learn_time_embedding: - time_embedding_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0) self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embedding_dim) else: self.time_proj = None self.time_embedding = None - if addition_embed_type == "text_time": - self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0) - self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embedding_dim) - elif addition_embed_type is not None: - raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") - self.time_embed_act = None self.down_subblocks = nn.ModuleList([]) @@ -328,7 +314,7 @@ def __init__( self.down_subblocks.append( CrossAttnSubBlock2D( has_crossattn=use_crossattention, - in_channels=input_channel + channels_base['down - out'][subblock_counter], + in_channels=input_channel + channels_base["down - out"][subblock_counter], out_channels=output_channel, temb_channels=time_embedding_dim, transformer_layers_per_block=transformer_layers_per_block[i], @@ -342,7 +328,7 @@ def __init__( self.down_subblocks.append( CrossAttnSubBlock2D( has_crossattn=use_crossattention, - in_channels=output_channel + channels_base['down - out'][subblock_counter], + in_channels=output_channel + channels_base["down - out"][subblock_counter], out_channels=output_channel, temb_channels=time_embedding_dim, transformer_layers_per_block=transformer_layers_per_block[i], @@ -356,14 +342,14 @@ def __init__( if i < len(down_block_types) - 1: self.down_subblocks.append( DownSubBlock2D( - in_channels=output_channel + channels_base['down - out'][subblock_counter], + in_channels=output_channel + channels_base["down - out"][subblock_counter], out_channels=output_channel, ) ) subblock_counter += 1 # mid - mid_in_channels = block_out_channels[-1] + channels_base['down - out'][subblock_counter] + mid_in_channels = block_out_channels[-1] + channels_base["down - out"][subblock_counter] mid_out_channels = block_out_channels[-1] self.mid_block = UNetMidBlock2DCrossAttn( @@ -398,7 +384,7 @@ def __init__( # todo - better comment # Information is passed from base to ctrl _before_ each subblock. We therefore use the 'in' channels. # As the information is concatted in ctrl, we don't need to change channel sizes. So channels in = channels out. - for c in channels_base["down - out"]: # change down - in to down - out + for c in channels_base["down - out"]: # change down - in to down - out self.down_zero_convs_b2c.append(self._make_zero_conv(c, c)) # 4.2 - Connections from ctrl encoder to base encoder @@ -721,7 +707,9 @@ def forward( add_embeds = add_embeds.to(temb.dtype) aug_emb = self.base_add_embedding(add_embeds) else: - raise NotImplementedError() + raise ValueError( + f"ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL, so addition_embed_type = {self.base_addition_embed_type} is currently not supported." + ) temb = temb + aug_emb if aug_emb is not None else temb @@ -768,7 +756,7 @@ def forward( hs_base.append(h_base) hs_ctrl.append(h_ctrl) - h_ctrl = torch.cat([h_ctrl, self.down_zero_convs_b2c[-1](h_base)], dim=1) # concat base -> ctrl + h_ctrl = torch.cat([h_ctrl, self.down_zero_convs_b2c[-1](h_base)], dim=1) # concat base -> ctrl # 2 - mid h_base = self.base_mid_block(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # apply base subblock @@ -959,8 +947,8 @@ def forward( class CrossAttnUpSubBlock2D(nn.Module): def __init__(self): """ - In the context of ControlNet-XS, `CrossAttnUpSubBlock2D` are only loaded from existing modules, and not created from scratch. - Therefore, `__init__` is left almost empty. + In the context of ControlNet-XS, `CrossAttnUpSubBlock2D` are only loaded from existing modules, and not created from scratch. + Therefore, `__init__` is left almost empty. """ super().__init__() self.gradient_checkpointing = False diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index f97c3dbebe2c..128395cc161a 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -325,7 +325,6 @@ def forward( residual = hidden_states hidden_states = self.norm(hidden_states) - if not self.use_linear_projection: hidden_states = ( self.proj_in(hidden_states, scale=lora_scale) @@ -359,8 +358,6 @@ def forward( timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype ) - - # 2. Blocks if self.caption_projection is not None: batch_size = hidden_states.shape[0] diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index f40e75dcabe3..cdc5c6a1df09 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -111,6 +111,8 @@ class ControlNetXSPipelineFastTests( image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + test_attention_slicing = False + def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) unet = UNet2DConditionModel( @@ -131,7 +133,7 @@ def get_dummy_components(self, time_cond_proj_dim=None): size_ratio=0.5, num_attention_heads=2, learn_time_embedding=True, - conditioning_embedding_out_channels=(16,32), + conditioning_embedding_out_channels=(16, 32), ) torch.manual_seed(0) scheduler = DDIMScheduler( diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index 253b4ae4b0fe..b27f0e88fcf2 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -61,6 +61,8 @@ class StableDiffusionXLControlNetXSPipelineFastTests( image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + test_attention_slicing = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( @@ -84,8 +86,8 @@ def get_dummy_components(self): controlnet_addon = ControlNetXSAddon.from_unet( base_model=unet, size_ratio=0.5, - learn_time_embedding=True, - conditioning_embedding_out_channels=(16,32), + learn_time_embedding=True, + conditioning_embedding_out_channels=(16, 32), ) torch.manual_seed(0) scheduler = EulerDiscreteScheduler( @@ -308,7 +310,7 @@ def test_stable_diffusion_xl_prompt_embeds(self): @slow @require_torch_gpu -class ControlNetSDXLPipelineXSSlowTests(unittest.TestCase): +class StableDiffusionXLControlNetXSPipelineSlowTests(unittest.TestCase): def tearDown(self): super().tearDown() gc.collect() From 6214da71dcf169639ee92c135bc63158896a5f99 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Fri, 19 Jan 2024 13:51:19 +0100 Subject: [PATCH 26/75] Added value ckecks | Updated model_cpu_offload_seq --- src/diffusers/models/controlnet_xs.py | 31 ++++++++++++++----- .../controlnet_xs/pipeline_controlnet_xs.py | 2 +- .../pipeline_controlnet_xs_sd_xl.py | 5 ++- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 12cd0263b02b..2471954bf52f 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -255,14 +255,14 @@ def __init__( "mid - out": 1280, "up - in": [1280, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320], }, - attention_head_dim=[4], - block_out_channels=[4, 8, 16, 16], - cross_attention_dim=1024, - down_block_types=["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], - sample_size=96, + attention_head_dim: Union[int, Tuple[int]] = 4, + block_out_channels : Tuple[int] = (4, 8, 16, 16), + cross_attention_dim: int =1024, + down_block_types: Tuple[str]=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), + sample_size: Optional[int]=96, # todo understand transformer_layers_per_block: Union[int, Tuple[int]] = 1, - upcast_attention=True, - norm_num_groups=32, + upcast_attention: bool = True, + norm_num_groups: int = 32, # todo: rename max_norm_num_groups? ): super().__init__() @@ -278,7 +278,22 @@ def __init__( # Check inputs if conditioning_channel_order not in ["rgb", "bgr"]: raise ValueError(f"unknown `conditioning_channel_order`: {conditioning_channel_order}") - # todo - other checks + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + # todo: attention_head_dim can be int, not list(int) + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) # input self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1) diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index f51ef5df721f..90932902a485 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -123,7 +123,7 @@ class StableDiffusionControlNetXSPipeline( A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - model_cpu_offload_seq = "text_encoder->unet->vae>controlnet" + model_cpu_offload_seq = "text_encoder->controlnet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index 631b700a8a89..fd4eb796e7ad 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -133,8 +133,7 @@ class StableDiffusionXLControlNetXSPipeline( watermarker is used. """ - # leave controlnet out on purpose because it iterates with unet - model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae->controlnet" + model_cpu_offload_seq = "text_encoder->text_encoder_2->controlnet->vae" _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( @@ -172,7 +171,7 @@ def __init__( controlnet_addon=controlnet_addon, scheduler=scheduler, ) - self.controlnet = ControlNetXSModel(base_model=unet, ctrl_model=controlnet_addon) + self.controlnet = ControlNetXSModel(base_model=unet, ctrl_model=controlnet_addon) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( From 54bb91db7b40449b2aef656a9a779e34f9ba97d2 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 25 Jan 2024 19:06:12 +0100 Subject: [PATCH 27/75] accelerate-offloading works ; fast tests work --- src/diffusers/models/controlnet_xs.py | 14 ++- .../controlnet/pipeline_controlnet.py | 6 +- .../controlnet_xs/pipeline_controlnet_xs.py | 85 ++++++++++++++++--- .../pipeline_controlnet_xs_sd_xl.py | 55 +++++++----- .../controlnet_xs/test_controlnetxs.py | 13 ++- .../controlnet_xs/test_controlnetxs_sdxl.py | 12 ++- 6 files changed, 147 insertions(+), 38 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 2471954bf52f..d0839bcbcc61 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -520,10 +520,12 @@ def __init__( ): super().__init__() + self.condition_downscale_factor = 2 ** (len(ctrl_model.config.conditioning_embedding_out_channels) - 1) + # 1 - Save options + self.class_embed_type = base_model.config.class_embed_type self.use_ctrl_time_embedding = ctrl_model.config.learn_time_embedding self.conditioning_channel_order = ctrl_model.config.conditioning_channel_order - self.class_embed_type = base_model.config.class_embed_type # 2 - Save control model parts self.ctrl_time_embedding = ctrl_model.time_embedding @@ -539,6 +541,7 @@ def __init__( self.up_zero_convs_c2b = ctrl_model.up_zero_convs_c2b # 4 - Save base model parts + self.base_in_channels = base_model.config.in_channels self.base_time_proj = base_model.time_proj self.base_time_embedding = base_model.time_embedding self.base_class_embedding = base_model.class_embedding @@ -553,6 +556,8 @@ def __init__( self.base_add_time_proj = base_model.add_time_proj if hasattr(base_model, "add_embedding"): self.base_add_embedding = base_model.add_embedding + if hasattr(base_model.config, "addition_time_embed_dim"): + self.base_addition_time_embed_dim = base_model.config.addition_time_embed_dim # 4.2 - Decompose blocks of base model into subblocks for block in base_model.down_blocks: @@ -594,6 +599,13 @@ def __init__( self.time_embedding_mix = time_embedding_mix + @torch.no_grad() + def _check_if_vae_compatible(self, vae: AutoencoderKL): + condition_downscale_factor = self.condition_downscale_factor + vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + compatible = self.condition_downscale_factor == vae_downscale_factor + return compatible, condition_downscale_factor, vae_downscale_factor + def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 6bdc281ef8bf..bb6a9a0ba58a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -1171,7 +1171,7 @@ def __call__( is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): + for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: @@ -1198,6 +1198,7 @@ def __call__( controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] + print(f'Denoising step {i} > Right before controlnet application : Device type of controlnet >> ',self.controlnet.device.type) down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, @@ -1207,6 +1208,7 @@ def __call__( guess_mode=guess_mode, return_dict=False, ) + print(f'Denoising step {i} > Right after controlnet application : Device type of controlnet >> ',self.controlnet.device.type) if guess_mode and self.do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. @@ -1240,6 +1242,7 @@ def __call__( callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] + print('btw, calling callback_on_step_end') callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) @@ -1250,6 +1253,7 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: + print('btw, calling callback') step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 90932902a485..f905d4e1dc01 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -19,11 +19,11 @@ import PIL.Image import torch import torch.nn.functional as F -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetXSAddon, ControlNetXSModel, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, ControlNetXSAddon, ControlNetXSModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -123,8 +123,8 @@ class StableDiffusionControlNetXSPipeline( A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - model_cpu_offload_seq = "text_encoder->controlnet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + model_cpu_offload_seq = "text_encoder->image_encoder->controlnet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] def __init__( @@ -132,11 +132,11 @@ def __init__( vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - controlnet_addon: ControlNetXSAddon, + controlnet: ControlNetXSModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() @@ -161,7 +161,7 @@ def __init__( vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor, - ) = controlnet_addon._check_if_vae_compatible(vae) + ) = controlnet._check_if_vae_compatible(vae) if not vae_compatible: raise ValueError( f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXSAddon model ({cnxs_condition_downsample_factor}) need to be equal. Consider building the ControlNetXSAddon model with different `conditioning_embedding_out_channels`." @@ -171,13 +171,12 @@ def __init__( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, - unet=unet, - controlnet_addon=controlnet_addon, + controlnet=controlnet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) - self.controlnet = ControlNetXSModel(base_model=unet, ctrl_model=controlnet_addon) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( @@ -185,6 +184,27 @@ def __init__( ) self.register_to_config(requires_safety_checker=requires_safety_checker) + def from_pretrained(components_path, addon_path, components_kwargs={}, addon_kwargs={}): + """ + todo: docstring + """ + from ..stable_diffusion import StableDiffusionPipeline # todo Q: need to import here to avoid circular dependency? + + components = StableDiffusionPipeline.from_pretrained(components_path, **components_kwargs).components + controlnet_addon = ControlNetXSAddon.from_pretrained(addon_path, **addon_kwargs) + + # todo: what if StableDiffusionPipeline has more params than StableDiffusionControlNetXSPipeline + # eg if some features are not implemented in cnxs yet? + + unet = components["unet"] + components = {k:v for k,v in components.items() if k != "unet"} + + controlnet = ControlNetXSModel(unet, controlnet_addon) + return StableDiffusionControlNetXSPipeline(controlnet=controlnet, **components) + + def save_pretrained(*args, **kwargs): + raise RuntimeError("Can't save a `StableDiffusionControlNetXSPipeline`. Save the `controlnet_addon` and all other components separately.") + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" @@ -433,6 +453,31 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -687,6 +732,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -821,12 +867,21 @@ def __call__( lora_scale=text_encoder_lora_scale, clip_skip=clip_skip, ) + # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Prepare image if isinstance(controlnet, ControlNetXSModel): image = self.prepare_image( @@ -848,7 +903,7 @@ def __call__( timesteps = self.scheduler.timesteps # 6. Prepare latent variables - num_channels_latents = self.unet.config.in_channels + num_channels_latents = self.controlnet.base_in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -863,16 +918,18 @@ def __call__( # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - is_unet_compiled = is_compiled_module(self.unet) is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + if is_controlnet_compiled and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -888,6 +945,7 @@ def __call__( timestep=t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, return_dict=True, ).sample else: @@ -898,6 +956,7 @@ def __call__( controlnet_cond=image, conditioning_scale=controlnet_conditioning_scale, cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, return_dict=True, ).sample diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index fd4eb796e7ad..b67a6173eb3d 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -143,8 +143,7 @@ def __init__( text_encoder_2: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, - unet: UNet2DConditionModel, - controlnet_addon: ControlNetXSAddon, + controlnet: ControlNetXSModel, scheduler: KarrasDiffusionSchedulers, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, @@ -155,7 +154,7 @@ def __init__( vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor, - ) = controlnet_addon._check_if_vae_compatible(vae) + ) = controlnet._check_if_vae_compatible(vae) if not vae_compatible: raise ValueError( f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXSAddon model ({cnxs_condition_downsample_factor}) need to be equal. Consider building the ControlNetXSAddon model with different `conditioning_embedding_out_channels`." @@ -167,11 +166,9 @@ def __init__( text_encoder_2=text_encoder_2, tokenizer=tokenizer, tokenizer_2=tokenizer_2, - unet=unet, - controlnet_addon=controlnet_addon, + controlnet=controlnet, scheduler=scheduler, - ) - self.controlnet = ControlNetXSModel(base_model=unet, ctrl_model=controlnet_addon) + ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( @@ -186,6 +183,28 @@ def __init__( self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + def from_pretrained(components_path, addon_path, components_kwargs={}, addon_kwargs={}): + """ + todo: docstring + """ + from ..stable_diffusion import StableDiffusionXLPipeline # todo Q: need to import here to avoid circular dependency? + + components = StableDiffusionXLPipeline.from_pretrained(components_path, **components_kwargs).components + controlnet_addon = ControlNetXSAddon.from_pretrained(addon_path, **addon_kwargs) + + # todo: what if StableDiffusionXLPipeline has more params than StableDiffusionControlNetXSPipeline + # eg if some features are not implemented in cnxs yet? + + unet = components["unet"] + components = {k:v for k,v in components.items() if k != "unet"} + + controlnet = ControlNetXSModel(unet, controlnet_addon) + return StableDiffusionXLControlNetXSPipeline(controlnet=controlnet, **components) + + def save_pretrained(*args, **kwargs): + raise RuntimeError("Can't save a `StableDiffusionControlNetXSPipeline`. Save the `controlnet_addon` and all other components separately.") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" @@ -219,7 +238,6 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( self, prompt: str, @@ -415,7 +433,7 @@ def encode_prompt( if self.text_encoder_2 is not None: prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: - prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + prompt_embeds = prompt_embeds.to(dtype=self.controlnet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method @@ -429,7 +447,7 @@ def encode_prompt( if self.text_encoder_2 is not None: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.controlnet.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -663,16 +681,15 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids def _get_add_time_ids( self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + self.controlnet.base_addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + expected_add_embed_dim = self.controlnet.base_add_embedding.linear_1.in_features if expected_add_embed_dim != passed_add_embed_dim: raise ValueError( @@ -723,12 +740,13 @@ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): """ if not hasattr(self, "unet"): raise ValueError("The pipeline must have `unet` for using FreeU.") - self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + # todo: check if works + self.controlnet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" - self.unet.disable_freeu() + # todo: check if works + self.controlnet.disable_freeu() @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -969,7 +987,7 @@ def __call__( timesteps = self.scheduler.timesteps # 6. Prepare latent variables - num_channels_latents = self.unet.config.in_channels + num_channels_latents = self.controlnet.base_in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -1027,14 +1045,13 @@ def __call__( # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - is_unet_compiled = is_compiled_module(self.unet) is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + if is_controlnet_compiled and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index cdc5c6a1df09..d1b73a0d4cdc 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -24,6 +24,7 @@ from diffusers import ( AutoencoderKL, ControlNetXSAddon, + ControlNetXSModel, DDIMScheduler, LCMScheduler, StableDiffusionControlNetXSPipeline, @@ -135,6 +136,7 @@ def get_dummy_components(self, time_cond_proj_dim=None): learn_time_embedding=True, conditioning_embedding_out_channels=(16, 32), ) + controlnet = ControlNetXSModel(base_model=unet, ctrl_model=controlnet_addon) torch.manual_seed(0) scheduler = DDIMScheduler( beta_start=0.00085, @@ -169,14 +171,14 @@ def get_dummy_components(self, time_cond_proj_dim=None): tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") components = { - "unet": unet, - "controlnet_addon": controlnet_addon, + "controlnet": controlnet, "scheduler": scheduler, "vae": vae, "text_encoder": text_encoder, "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components @@ -236,6 +238,13 @@ def test_controlnet_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_save_load_local(self): + # Todo Umer: test saving controlnet addon, but not the entire pipe + pass + + def test_save_load_optional_components(self): + # Todo Umer: comment why not needed (b/c save_pretrained isn't meant to be used) + pass @slow @require_torch_gpu diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index b27f0e88fcf2..af183066032d 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -23,6 +23,7 @@ from diffusers import ( AutoencoderKL, ControlNetXSAddon, + ControlNetXSModel, EulerDiscreteScheduler, StableDiffusionXLControlNetXSPipeline, UNet2DConditionModel, @@ -89,6 +90,7 @@ def get_dummy_components(self): learn_time_embedding=True, conditioning_embedding_out_channels=(16, 32), ) + controlnet = ControlNetXSModel(base_model=unet, ctrl_model=controlnet_addon) torch.manual_seed(0) scheduler = EulerDiscreteScheduler( beta_start=0.00085, @@ -128,8 +130,7 @@ def get_dummy_components(self): tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") components = { - "unet": unet, - "controlnet_addon": controlnet_addon, + "controlnet": controlnet, "scheduler": scheduler, "vae": vae, "text_encoder": text_encoder, @@ -307,6 +308,13 @@ def test_stable_diffusion_xl_prompt_embeds(self): # make sure that it's equal assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1.1e-4 + def test_save_load_local(self): + # Todo Umer: test saving controlnet addon, but not the entire pipe + pass + + def test_save_load_optional_components(self): + # Todo Umer: comment why not needed (b/c save_pretrained isn't meant to be used) + pass @slow @require_torch_gpu From a709ef8d46b658f877178bbed8853b05266c69b9 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Fri, 26 Jan 2024 06:26:26 +0100 Subject: [PATCH 28/75] Made unet & addon explicit in controlnet --- src/diffusers/models/controlnet_xs.py | 147 ++++++++---------- .../controlnet_xs/pipeline_controlnet_xs.py | 37 ++--- .../pipeline_controlnet_xs_sd_xl.py | 41 ++--- 3 files changed, 92 insertions(+), 133 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index d0839bcbcc61..ae5debc15ec8 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -428,18 +428,11 @@ def forward(self, *args, **kwargs): "A ControlNetXSAddonModel cannot be run by itself. Pass it into a ControlNetXSModel model instead." ) - @torch.no_grad() - def _check_if_vae_compatible(self, vae: AutoencoderKL): - condition_downscale_factor = 2 ** (len(self.config.conditioning_embedding_out_channels) - 1) - vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1) - compatible = condition_downscale_factor == vae_downscale_factor - return compatible, condition_downscale_factor, vae_downscale_factor - def _make_zero_conv(self, in_channels, out_channels=None): return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) -class ControlNetXSModel(ModelMixin, ConfigMixin): +class ControlNetXSModel(nn.Module): r""" A ControlNet-XS model @@ -511,7 +504,6 @@ def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_ return cls(base_model=base_model, ctrl_model=controlnet_addon, time_embedding_mix=time_embedding_mix) - @register_to_config def __init__( self, base_model: UNet2DConditionModel, @@ -520,46 +512,14 @@ def __init__( ): super().__init__() - self.condition_downscale_factor = 2 ** (len(ctrl_model.config.conditioning_embedding_out_channels) - 1) - - # 1 - Save options - self.class_embed_type = base_model.config.class_embed_type - self.use_ctrl_time_embedding = ctrl_model.config.learn_time_embedding - self.conditioning_channel_order = ctrl_model.config.conditioning_channel_order - - # 2 - Save control model parts - self.ctrl_time_embedding = ctrl_model.time_embedding - self.ctrl_conv_in = ctrl_model.conv_in - self.ctrl_controlnet_cond_embedding = ctrl_model.controlnet_cond_embedding - self.ctrl_down_subblocks = ctrl_model.down_subblocks - self.ctrl_mid_block = ctrl_model.mid_block - - # 3 - Save connections - self.down_zero_convs_b2c = ctrl_model.down_zero_convs_b2c - self.down_zero_convs_c2b = ctrl_model.down_zero_convs_c2b - self.mid_zero_convs_c2b = ctrl_model.mid_zero_convs_c2b - self.up_zero_convs_c2b = ctrl_model.up_zero_convs_c2b - - # 4 - Save base model parts - self.base_in_channels = base_model.config.in_channels - self.base_time_proj = base_model.time_proj - self.base_time_embedding = base_model.time_embedding - self.base_class_embedding = base_model.class_embedding - self.base_addition_embed_type = base_model.config.addition_embed_type - self.base_conv_in = base_model.conv_in + self.ctrl_model = ctrl_model + self.base_model = base_model + self.time_embedding_mix = time_embedding_mix + + # Decompose blocks of base model into subblocks self.base_down_subblocks = nn.ModuleList() - self.base_mid_block = base_model.mid_block self.base_up_subblocks = nn.ModuleList() - # 4.1 - SDXL specific components - if hasattr(base_model, "add_time_proj"): - self.base_add_time_proj = base_model.add_time_proj - if hasattr(base_model, "add_embedding"): - self.base_add_embedding = base_model.add_embedding - if hasattr(base_model.config, "addition_time_embed_dim"): - self.base_addition_time_embed_dim = base_model.config.addition_time_embed_dim - - # 4.2 - Decompose blocks of base model into subblocks for block in base_model.down_blocks: # Each ResNet / Attention pair is a subblock resnets = block.resnets @@ -593,17 +553,11 @@ def __init__( for r, a, u in zip(resnets, attentions, upsamplers): self.base_up_subblocks.append(CrossAttnUpSubBlock2D.from_modules(r, a, u)) - self.base_conv_norm_out = base_model.conv_norm_out - self.base_conv_act = base_model.conv_act - self.base_conv_out = base_model.conv_out - - self.time_embedding_mix = time_embedding_mix - @torch.no_grad() def _check_if_vae_compatible(self, vae: AutoencoderKL): - condition_downscale_factor = self.condition_downscale_factor + condition_downscale_factor = 2 ** (len(self.ctrl_model.config.conditioning_embedding_out_channels) - 1) vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1) - compatible = self.condition_downscale_factor == vae_downscale_factor + compatible = condition_downscale_factor == vae_downscale_factor return compatible, condition_downscale_factor, vae_downscale_factor def forward( @@ -619,6 +573,7 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, return_dict: bool = True, + do_control: bool = True, ) -> Union[ControlNetXSOutput, Tuple]: """ The [`ControlNetModel`] forward method. @@ -659,8 +614,21 @@ def forward( tuple is returned where the first element is the sample tensor. """ + if not do_control: + return self.base_model( + sample=sample, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + class_labels=class_labels, + timestep_cond=timestep_cond, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=return_dict + ) + # check channel order - if self.conditioning_channel_order == "bgr": + if self.ctrl_model.config.conditioning_channel_order == "bgr": controlnet_cond = torch.flip(controlnet_cond, dims=[1]) # prepare attention_mask @@ -685,38 +653,38 @@ def forward( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) - t_emb = self.base_time_proj(timesteps) + t_emb = self.base_model.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) - if self.use_ctrl_time_embedding: - ctrl_temb = self.ctrl_time_embedding(t_emb, timestep_cond) - base_temb = self.base_time_embedding(t_emb, timestep_cond) - interpolation_param = self.config.time_embedding_mix**0.3 + if self.ctrl_model.config.learn_time_embedding: + ctrl_temb = self.ctrl_model.time_embedding(t_emb, timestep_cond) + base_temb = self.base_model.time_embedding(t_emb, timestep_cond) + interpolation_param = self.time_embedding_mix**0.3 temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) else: - temb = self.base_time_embedding(t_emb) + temb = self.base_model.time_embedding(t_emb) # added time & text embeddings aug_emb = None - if self.base_class_embedding is not None: + if self.base_model.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") - if self.class_embed_type == "timestep": + if self.base_model.config.class_embed_type == "timestep": class_labels = self.base_time_proj(class_labels) - class_emb = self.base_class_embedding(class_labels).to(dtype=self.dtype) + class_emb = self.base_model.class_embedding(class_labels).to(dtype=self.dtype) temb = temb + class_emb - if self.base_addition_embed_type is None: + if self.base_model.config.addition_embed_type is None: pass - elif self.base_addition_embed_type == "text_time": + elif self.base_model.config.addition_embed_type == "text_time": # SDXL - style if "text_embeds" not in added_cond_kwargs: raise ValueError( @@ -728,14 +696,14 @@ def forward( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") - time_embeds = self.base_add_time_proj(time_ids.flatten()) + time_embeds = self.base_model.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(temb.dtype) - aug_emb = self.base_add_embedding(add_embeds) + aug_emb = self.base_model.add_embedding(add_embeds) else: raise ValueError( - f"ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL, so addition_embed_type = {self.base_addition_embed_type} is currently not supported." + f"ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL, so addition_embed_type = {self.base_model.config.addition_embed_type} is currently not supported." ) temb = temb + aug_emb if aug_emb is not None else temb @@ -744,32 +712,41 @@ def forward( cemb = encoder_hidden_states # Preparation - guided_hint = self.ctrl_controlnet_cond_embedding(controlnet_cond) + guided_hint = self.ctrl_model.controlnet_cond_embedding(controlnet_cond) h_ctrl = h_base = sample hs_base, hs_ctrl = [], [] # Cross Control + # Let's first define variables to shorten notation + base_down_subblocks = self.base_down_subblocks + ctrl_down_subblocks = self.ctrl_model.down_subblocks + + down_zero_convs_b2c = self.ctrl_model.down_zero_convs_b2c + down_zero_convs_c2b = self.ctrl_model.down_zero_convs_c2b + mid_zero_convs_c2b = self.ctrl_model.mid_zero_convs_c2b + up_zero_convs_c2b = self.ctrl_model.up_zero_convs_c2b + # 1 - conv in & down # The base -> ctrl connections are "delayed" by 1 subblock, because we want to "wait" to ensure the new information from the last ctrl -> base connection is also considered # Therefore, the connections iterate over: # ctrl -> base: conv_in | subblock 1 | ... | subblock n # base -> ctrl: | subblock 1 | ... | subblock n | mid block - h_base = self.base_conv_in(h_base) - h_ctrl = self.ctrl_conv_in(h_ctrl) + h_base = self.base_model.conv_in(h_base) + h_ctrl = self.ctrl_model.conv_in(h_ctrl) if guided_hint is not None: h_ctrl += guided_hint - h_base = h_base + self.down_zero_convs_c2b[0](h_ctrl) * conditioning_scale # add ctrl -> base + h_base = h_base + down_zero_convs_c2b[0](h_ctrl) * conditioning_scale # add ctrl -> base hs_base.append(h_base) hs_ctrl.append(h_ctrl) for b, c, b2c, c2b in zip( - self.base_down_subblocks, - self.ctrl_down_subblocks, - self.down_zero_convs_b2c[:-1], - self.down_zero_convs_c2b[1:], + base_down_subblocks, + ctrl_down_subblocks, + down_zero_convs_b2c[:-1], + down_zero_convs_c2b[1:], ): if isinstance(b, CrossAttnSubBlock2D): additional_params = [temb, cemb, attention_mask, cross_attention_kwargs] @@ -783,24 +760,24 @@ def forward( hs_base.append(h_base) hs_ctrl.append(h_ctrl) - h_ctrl = torch.cat([h_ctrl, self.down_zero_convs_b2c[-1](h_base)], dim=1) # concat base -> ctrl + h_ctrl = torch.cat([h_ctrl, down_zero_convs_b2c[-1](h_base)], dim=1) # concat base -> ctrl # 2 - mid - h_base = self.base_mid_block(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # apply base subblock - h_ctrl = self.ctrl_mid_block(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # apply ctrl subblock - h_base = h_base + self.mid_zero_convs_c2b(h_ctrl) * conditioning_scale # add ctrl -> base + h_base = self.base_model.mid_block(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # apply base subblock + h_ctrl = self.ctrl_model.mid_block(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # apply ctrl subblock + h_base = h_base + mid_zero_convs_c2b(h_ctrl) * conditioning_scale # add ctrl -> base # 3 - up for b, c2b, skip_c, skip_b in zip( - self.base_up_subblocks, self.up_zero_convs_c2b, reversed(hs_ctrl), reversed(hs_base) + self.base_up_subblocks, up_zero_convs_c2b, reversed(hs_ctrl), reversed(hs_base) ): h_base = h_base + c2b(skip_c) * conditioning_scale # add info from ctrl encoder h_base = torch.cat([h_base, skip_b], dim=1) # concat info from base encoder+ctrl encoder h_base = b(h_base, temb, cemb, attention_mask, cross_attention_kwargs) - h_base = self.base_conv_norm_out(h_base) - h_base = self.base_conv_act(h_base) - h_base = self.base_conv_out(h_base) + h_base = self.base_model.conv_norm_out(h_base) + h_base = self.base_model.conv_act(h_base) + h_base = self.base_model.conv_out(h_base) if not return_dict: return h_base diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index f905d4e1dc01..f545f5c1f4f5 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -903,7 +903,7 @@ def __call__( timesteps = self.scheduler.timesteps # 6. Prepare latent variables - num_channels_latents = self.controlnet.base_in_channels + num_channels_latents = self.controlnet.base_model.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -936,29 +936,20 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - dont_control = ( - i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end + do_control = ( + i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end ) - if dont_control: - noise_pred = self.unet( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=True, - ).sample - else: - noise_pred = self.controlnet( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - controlnet_cond=image, - conditioning_scale=controlnet_conditioning_scale, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=True, - ).sample + noise_pred = self.controlnet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=True, + do_control=do_control, + ).sample # perform guidance if do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index b67a6173eb3d..cac9035d2c61 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -687,9 +687,9 @@ def _get_add_time_ids( add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.controlnet.base_addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + self.controlnet.base_model.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) - expected_add_embed_dim = self.controlnet.base_add_embedding.linear_1.in_features + expected_add_embed_dim = self.controlnet.base_model.add_embedding.linear_1.in_features if expected_add_embed_dim != passed_add_embed_dim: raise ValueError( @@ -987,7 +987,7 @@ def __call__( timesteps = self.scheduler.timesteps # 6. Prepare latent variables - num_channels_latents = self.controlnet.base_in_channels + num_channels_latents = self.controlnet.base_model.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -1060,29 +1060,20 @@ def __call__( added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} # predict the noise residual - dont_control = ( - i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end + do_control = ( + i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end ) - if dont_control: - noise_pred = self.unet( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=True, - ).sample - else: - noise_pred = self.controlnet( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - controlnet_cond=image, - conditioning_scale=controlnet_conditioning_scale, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=True, - ).sample + noise_pred = self.controlnet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=True, + do_control=do_control, + ).sample # perform guidance if do_classifier_free_guidance: From eecf9a5f92189b9f99a2d62cb71e7ca635b5ae68 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Fri, 26 Jan 2024 06:35:40 +0100 Subject: [PATCH 29/75] Updated slow tests --- tests/pipelines/controlnet_xs/test_controlnetxs.py | 10 ++++------ .../pipelines/controlnet_xs/test_controlnetxs_sdxl.py | 10 ++++------ 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index d1b73a0d4cdc..892c8589477a 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -255,10 +255,9 @@ def tearDown(self): torch.cuda.empty_cache() def test_canny(self): - controlnet_addon = ControlNetXSAddon.from_pretrained("UmerHA/Testing-ConrolNetXS-SD2.1-canny") - pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - "stabilityai/stable-diffusion-2-1", safety_checker=None, controlnet_addon=controlnet_addon + components_path="stabilityai/stable-diffusion-2-1", + addon_path="UmerHA/Testing-ConrolNetXS-SD2.1-canny", ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -280,10 +279,9 @@ def test_canny(self): assert np.allclose(original_image, expected_image, atol=1e-04) def test_depth(self): - controlnet_addon = ControlNetXSAddon.from_pretrained("todo umer") - pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - "stabilityai/stable-diffusion-2-1", safety_checker=None, controlnet_addon=controlnet_addon + components_path="stabilityai/stable-diffusion-2-1", + addon_path="todo umer", ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index af183066032d..fe14b324ee9f 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -325,10 +325,9 @@ def tearDown(self): torch.cuda.empty_cache() def test_canny(self): - controlnet_addon = ControlNetXSAddon.from_pretrained("UmerHA/Testing-ConrolNetXS-SDXL-canny") - pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", controlnet_addon=controlnet_addon + components_path="stabilityai/stable-diffusion-xl-base-1.0", + addon_path="UmerHA/Testing-ConrolNetXS-SDXL-canny" ) pipe.enable_sequential_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -348,10 +347,9 @@ def test_canny(self): assert np.allclose(original_image, expected_image, atol=1e-04) def test_depth(self): - controlnet_addon = ControlNetXSAddon.from_pretrained("todo umer") - pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", controlnet_addon=controlnet_addon + components_path="stabilityai/stable-diffusion-xl-base-1.0", + addon_path="todo umer" ) pipe.enable_sequential_cpu_offload() pipe.set_progress_bar_config(disable=None) From 15affe31628dda8cdf57c7b25f8e6f4cec68d2cc Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Fri, 26 Jan 2024 06:49:19 +0100 Subject: [PATCH 30/75] Added dtype/device to ControlNetXS --- src/diffusers/models/controlnet_xs.py | 15 +++++++++++++++ .../pipelines/controlnet_xs/test_controlnetxs.py | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index ae5debc15ec8..c2c475aa767a 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -553,6 +553,21 @@ def __init__( for r, a, u in zip(resnets, attentions, upsamplers): self.base_up_subblocks.append(CrossAttnUpSubBlock2D.from_modules(r, a, u)) + @property + def device(self) -> torch.device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return self.base_model.device + + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return self.base_model.dtype + @torch.no_grad() def _check_if_vae_compatible(self, vae: AutoencoderKL): condition_downscale_factor = 2 ** (len(self.ctrl_model.config.conditioning_embedding_out_channels) - 1) diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index 892c8589477a..871f4093f497 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -273,7 +273,7 @@ def test_canny(self): image = output.images[0] assert image.shape == (768, 512, 3) - + original_image = image[-3:, -3:, -1].flatten() expected_image = np.array([0.1274, 0.1401, 0.147, 0.1185, 0.1555, 0.1492, 0.1565, 0.1474, 0.1701]) assert np.allclose(original_image, expected_image, atol=1e-04) From 08c8e3143768d3cecb36f2f80df80b6249a65959 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Fri, 26 Jan 2024 07:26:56 +0100 Subject: [PATCH 31/75] Filled in test model paths --- tests/pipelines/controlnet_xs/test_controlnetxs.py | 2 +- tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index 871f4093f497..6c3bff44bdca 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -281,7 +281,7 @@ def test_canny(self): def test_depth(self): pipe = StableDiffusionControlNetXSPipeline.from_pretrained( components_path="stabilityai/stable-diffusion-2-1", - addon_path="todo umer", + addon_path="UmerHA/Testing-ConrolNetXS-SD2.1-depth", ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index fe14b324ee9f..f490e1cfc9d4 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -349,7 +349,7 @@ def test_canny(self): def test_depth(self): pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( components_path="stabilityai/stable-diffusion-xl-base-1.0", - addon_path="todo umer" + addon_path="UmerHA/Testing-ConrolNetXS-SDXL-depth" ) pipe.enable_sequential_cpu_offload() pipe.set_progress_bar_config(disable=None) From 88285866fb2c5faaf96f4970f45fc29417726c12 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Fri, 26 Jan 2024 07:54:21 +0100 Subject: [PATCH 32/75] Added image_encoder/feature_extractor to XL pipe --- .../controlnet_xs/pipeline_controlnet_xs.py | 3 +- .../pipeline_controlnet_xs_sd_xl.py | 53 ++++++++++++++++--- 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index f545f5c1f4f5..116233c88417 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -36,6 +36,7 @@ ) from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion import StableDiffusionPipeline from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -188,8 +189,6 @@ def from_pretrained(components_path, addon_path, components_kwargs={}, addon_kwa """ todo: docstring """ - from ..stable_diffusion import StableDiffusionPipeline # todo Q: need to import here to avoid circular dependency? - components = StableDiffusionPipeline.from_pretrained(components_path, **components_kwargs).components controlnet_addon = ControlNetXSAddon.from_pretrained(addon_path, **addon_kwargs) diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index cac9035d2c61..44c63e5ecb42 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -19,7 +19,13 @@ import PIL.Image import torch import torch.nn.functional as F -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) from diffusers.utils.import_utils import is_invisible_watermark_available @@ -37,6 +43,7 @@ from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_xl import StableDiffusionXLPipeline from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput @@ -133,9 +140,15 @@ class StableDiffusionXLControlNetXSPipeline( watermarker is used. """ - model_cpu_offload_seq = "text_encoder->text_encoder_2->controlnet->vae" - _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] - + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->controlnet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] def __init__( self, vae: AutoencoderKL, @@ -147,6 +160,8 @@ def __init__( scheduler: KarrasDiffusionSchedulers, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, ): super().__init__() @@ -168,6 +183,8 @@ def __init__( tokenizer_2=tokenizer_2, controlnet=controlnet, scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) @@ -187,8 +204,6 @@ def from_pretrained(components_path, addon_path, components_kwargs={}, addon_kwa """ todo: docstring """ - from ..stable_diffusion import StableDiffusionXLPipeline # todo Q: need to import here to avoid circular dependency? - components = StableDiffusionXLPipeline.from_pretrained(components_path, **components_kwargs).components controlnet_addon = ControlNetXSAddon.from_pretrained(addon_path, **addon_kwargs) @@ -238,6 +253,7 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() + # todo: check if copy def encode_prompt( self, prompt: str, @@ -472,6 +488,31 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature From 72fce3e12e9c57bdc4727f1caea28c0840b4d206 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Tue, 30 Jan 2024 11:40:05 +0100 Subject: [PATCH 33/75] Fixed fast tests --- src/diffusers/models/controlnet_xs.py | 3 +- .../controlnet_xs/pipeline_controlnet_xs.py | 94 ++++++++++-- .../pipeline_controlnet_xs_sd_xl.py | 96 ++++++++++-- .../controlnet_xs/test_controlnetxs.py | 110 ++++++++++++-- .../controlnet_xs/test_controlnetxs_sdxl.py | 141 ++++++++++++++++-- 5 files changed, 401 insertions(+), 43 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index c2c475aa767a..b874bd2c5f80 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -289,11 +289,12 @@ def __init__( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) - # todo: attention_head_dim can be int, not list(int) if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." ) + elif isinstance(attention_head_dim, int): + attention_head_dim = [attention_head_dim] * len(down_block_types) # input self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1) diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 116233c88417..080d5d504d26 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -127,6 +127,7 @@ class StableDiffusionControlNetXSPipeline( model_cpu_offload_seq = "text_encoder->image_encoder->controlnet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, @@ -185,11 +186,12 @@ def __init__( ) self.register_to_config(requires_safety_checker=requires_safety_checker) - def from_pretrained(components_path, addon_path, components_kwargs={}, addon_kwargs={}): + @classmethod + def from_pretrained(cls, base_path, addon_path, base_kwargs={}, addon_kwargs={}): """ todo: docstring """ - components = StableDiffusionPipeline.from_pretrained(components_path, **components_kwargs).components + components = StableDiffusionPipeline.from_pretrained(base_path, **base_kwargs).components controlnet_addon = ControlNetXSAddon.from_pretrained(addon_path, **addon_kwargs) # todo: what if StableDiffusionPipeline has more params than StableDiffusionControlNetXSPipeline @@ -201,8 +203,15 @@ def from_pretrained(components_path, addon_path, components_kwargs={}, addon_kwa controlnet = ControlNetXSModel(unet, controlnet_addon) return StableDiffusionControlNetXSPipeline(controlnet=controlnet, **components) - def save_pretrained(*args, **kwargs): - raise RuntimeError("Can't save a `StableDiffusionControlNetXSPipeline`. Save the `controlnet_addon` and all other components separately.") + def save_pretrained(self, base_path, addon_path, base_kwargs={}, addon_kwargs={}): + """todo docs""" + components = {k:v for k,v in self.components.items() if k!="controlnet"} + components["unet"] = self.components["controlnet"].base_model + + controlnet_addon = self.components["controlnet"].ctrl_model + + StableDiffusionPipeline(**components).save_pretrained(base_path, **base_kwargs) + controlnet_addon.save_pretrained(addon_path, **addon_kwargs) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -533,14 +542,19 @@ def check_inputs( controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) if prompt is not None and prompt_embeds is not None: raise ValueError( @@ -686,7 +700,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. @@ -709,11 +722,35 @@ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): raise ValueError("The pipeline must have `unet` for using FreeU.") self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale + @property + def guidance_scale(self): + return self._guidance_scale + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip + @property + def clip_skip(self): + return self._clip_skip + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -734,13 +771,14 @@ def __call__( ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, control_guidance_start: float = 0.0, control_guidance_end: float = 1.0, clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, ): r""" The call function to the pipeline for generation. @@ -822,6 +860,23 @@ def __call__( second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # 1. Check inputs. Raise error if not correct @@ -835,8 +890,14 @@ def __call__( controlnet_conditioning_scale, control_guidance_start, control_guidance_end, + callback_on_step_end_tensor_inputs, ) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -922,6 +983,7 @@ def __call__( # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -957,6 +1019,16 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index 44c63e5ecb42..2789cbe77fcb 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -40,7 +40,7 @@ ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers -from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion_xl import StableDiffusionXLPipeline @@ -149,6 +149,8 @@ class StableDiffusionXLControlNetXSPipeline( "feature_extractor", "image_encoder", ] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + def __init__( self, vae: AutoencoderKL, @@ -200,11 +202,12 @@ def __init__( self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - def from_pretrained(components_path, addon_path, components_kwargs={}, addon_kwargs={}): + @classmethod + def from_pretrained(cls, base_path, addon_path, base_kwargs={}, addon_kwargs={}): """ todo: docstring """ - components = StableDiffusionXLPipeline.from_pretrained(components_path, **components_kwargs).components + components = StableDiffusionXLPipeline.from_pretrained(base_path, **base_kwargs).components controlnet_addon = ControlNetXSAddon.from_pretrained(addon_path, **addon_kwargs) # todo: what if StableDiffusionXLPipeline has more params than StableDiffusionControlNetXSPipeline @@ -216,9 +219,15 @@ def from_pretrained(components_path, addon_path, components_kwargs={}, addon_kwa controlnet = ControlNetXSModel(unet, controlnet_addon) return StableDiffusionXLControlNetXSPipeline(controlnet=controlnet, **components) - def save_pretrained(*args, **kwargs): - raise RuntimeError("Can't save a `StableDiffusionControlNetXSPipeline`. Save the `controlnet_addon` and all other components separately.") + def save_pretrained(self, base_path, addon_path, base_kwargs={}, addon_kwargs={}): + """todo docs""" + components = {k:v for k,v in self.components.items() if k!="controlnet"} + components["unet"] = self.components["controlnet"].base_model + + controlnet_addon = self.components["controlnet"].ctrl_model + StableDiffusionXLPipeline(**components).save_pretrained(base_path, **base_kwargs) + controlnet_addon.save_pretrained(addon_path, **addon_kwargs) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -546,15 +555,21 @@ def check_inputs( controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -760,7 +775,6 @@ def upcast_vae(self): self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. @@ -789,6 +803,31 @@ def disable_freeu(self): # todo: check if works self.controlnet.disable_freeu() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale + @property + def guidance_scale(self): + return self._guidance_scale + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip + @property + def clip_skip(self): + return self._clip_skip + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -812,8 +851,6 @@ def __call__( negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, control_guidance_start: float = 0.0, @@ -825,6 +862,9 @@ def __call__( negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, ): r""" The call function to the pipeline for generation. @@ -949,6 +989,23 @@ def __call__( If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is returned, otherwise a `tuple` is returned containing the output images. """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # 1. Check inputs. Raise error if not correct @@ -966,8 +1023,14 @@ def __call__( controlnet_conditioning_scale, control_guidance_start, control_guidance_end, + callback_on_step_end_tensor_inputs, ) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -1086,6 +1149,7 @@ def __call__( # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -1124,6 +1188,16 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index 6c3bff44bdca..98b4a1bf8b40 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -16,11 +16,13 @@ import gc import traceback import unittest +import tempfile import numpy as np import torch from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer +import diffusers from diffusers import ( AutoencoderKL, ControlNetXSAddon, @@ -30,6 +32,7 @@ StableDiffusionControlNetXSPipeline, UNet2DConditionModel, ) +from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -54,6 +57,7 @@ PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, + to_np ) @@ -238,13 +242,99 @@ def test_controlnet_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - def test_save_load_local(self): - # Todo Umer: test saving controlnet addon, but not the entire pipe - pass + def test_save_load_local(self, expected_max_difference=5e-4): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + + logger = logging.get_logger("diffusers.pipelines.pipeline_utils") + logger.setLevel(diffusers.logging.INFO) + + with tempfile.TemporaryDirectory() as tmpdir_components: + with tempfile.TemporaryDirectory() as tmpdir_addon: + pipe.save_pretrained( + base_path=tmpdir_components, + addon_path=tmpdir_addon, + base_kwargs=dict(safe_serialization=False), + addon_kwargs=dict(safe_serialization=False), + ) + + pipe_loaded = self.pipeline_class.from_pretrained( + base_path=tmpdir_components, + addon_path=tmpdir_addon + ) + + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, expected_max_difference) + + def test_save_load_optional_components(self, expected_max_difference=1e-4): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # set all optional components to None + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir_components: + with tempfile.TemporaryDirectory() as tmpdir_addon: + + pipe.save_pretrained( + base_path=tmpdir_components, + addon_path=tmpdir_addon, + base_kwargs=dict(safe_serialization=False), + addon_kwargs=dict(safe_serialization=False), + ) + + pipe_loaded = self.pipeline_class.from_pretrained( + base_path=tmpdir_components, + addon_path=tmpdir_addon + ) + + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(generator_device) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, expected_max_difference) - def test_save_load_optional_components(self): - # Todo Umer: comment why not needed (b/c save_pretrained isn't meant to be used) - pass @slow @require_torch_gpu @@ -256,7 +346,7 @@ def tearDown(self): def test_canny(self): pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - components_path="stabilityai/stable-diffusion-2-1", + base_path="stabilityai/stable-diffusion-2-1", addon_path="UmerHA/Testing-ConrolNetXS-SD2.1-canny", ) pipe.enable_model_cpu_offload() @@ -275,12 +365,12 @@ def test_canny(self): assert image.shape == (768, 512, 3) original_image = image[-3:, -3:, -1].flatten() - expected_image = np.array([0.1274, 0.1401, 0.147, 0.1185, 0.1555, 0.1492, 0.1565, 0.1474, 0.1701]) + expected_image = np.array([0.1462, 0.1518, 0.1583, 0.1332, 0.1655, 0.1629, 0.1646, 0.1595, 0.1762]) assert np.allclose(original_image, expected_image, atol=1e-04) def test_depth(self): pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - components_path="stabilityai/stable-diffusion-2-1", + base_path="stabilityai/stable-diffusion-2-1", addon_path="UmerHA/Testing-ConrolNetXS-SD2.1-depth", ) pipe.enable_model_cpu_offload() @@ -299,7 +389,7 @@ def test_depth(self): assert image.shape == (512, 512, 3) original_image = image[-3:, -3:, -1].flatten() - expected_image = np.array([0.1098, 0.1025, 0.1211, 0.1129, 0.1165, 0.1262, 0.1185, 0.1261, 0.1703]) + expected_image = np.array([0.1504, 0.1448, 0.1742, 0.155 , 0.1553, 0.1833, 0.1694, 0.1833, 0.2354]) assert np.allclose(original_image, expected_image, atol=1e-04) @require_python39_or_higher diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index f490e1cfc9d4..c854561d78d1 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -15,11 +15,13 @@ import gc import unittest +import tempfile import numpy as np import torch from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +import diffusers from diffusers import ( AutoencoderKL, ControlNetXSAddon, @@ -28,6 +30,7 @@ StableDiffusionXLControlNetXSPipeline, UNet2DConditionModel, ) +from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device from diffusers.utils.torch_utils import randn_tensor @@ -43,6 +46,7 @@ PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, + to_np ) @@ -137,6 +141,8 @@ def get_dummy_components(self): "tokenizer": tokenizer, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, + "feature_extractor": None, + "image_encoder": None, } return components @@ -308,13 +314,128 @@ def test_stable_diffusion_xl_prompt_embeds(self): # make sure that it's equal assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1.1e-4 - def test_save_load_local(self): - # Todo Umer: test saving controlnet addon, but not the entire pipe - pass + # copied from test_controlnetxs.py + def test_save_load_local(self, expected_max_difference=5e-4): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() - def test_save_load_optional_components(self): - # Todo Umer: comment why not needed (b/c save_pretrained isn't meant to be used) - pass + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + + logger = logging.get_logger("diffusers.pipelines.pipeline_utils") + logger.setLevel(diffusers.logging.INFO) + + with tempfile.TemporaryDirectory() as tmpdir_components: + with tempfile.TemporaryDirectory() as tmpdir_addon: + pipe.save_pretrained( + base_path=tmpdir_components, + addon_path=tmpdir_addon, + base_kwargs=dict(safe_serialization=False), + addon_kwargs=dict(safe_serialization=False), + ) + + pipe_loaded = self.pipeline_class.from_pretrained( + base_path=tmpdir_components, + addon_path=tmpdir_addon + ) + + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, expected_max_difference) + + def test_save_load_optional_components(self, expected_max_difference=1e-4): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + + # set all optional components to None + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + + tokenizer = components.pop("tokenizer") + tokenizer_2 = components.pop("tokenizer_2") + text_encoder = components.pop("text_encoder") + text_encoder_2 = components.pop("text_encoder_2") + + tokenizers = [tokenizer, tokenizer_2] if tokenizer is not None else [tokenizer_2] + text_encoders = [text_encoder, text_encoder_2] if text_encoder is not None else [text_encoder_2] + prompt = inputs.pop("prompt") + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt(tokenizers, text_encoders, prompt) + inputs["prompt_embeds"] = prompt_embeds + inputs["negative_prompt_embeds"] = negative_prompt_embeds + inputs["pooled_prompt_embeds"] = pooled_prompt_embeds + inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds + + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir_components: + with tempfile.TemporaryDirectory() as tmpdir_addon: + + pipe.save_pretrained( + base_path=tmpdir_components, + addon_path=tmpdir_addon, + base_kwargs=dict(safe_serialization=False), + addon_kwargs=dict(safe_serialization=False), + ) + + pipe_loaded = self.pipeline_class.from_pretrained( + base_path=tmpdir_components, + addon_path=tmpdir_addon + ) + + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(generator_device) + + _ = inputs.pop("prompt") + inputs["prompt_embeds"] = prompt_embeds + inputs["negative_prompt_embeds"] = negative_prompt_embeds + inputs["pooled_prompt_embeds"] = pooled_prompt_embeds + inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds + + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, expected_max_difference) @slow @require_torch_gpu @@ -326,7 +447,7 @@ def tearDown(self): def test_canny(self): pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( - components_path="stabilityai/stable-diffusion-xl-base-1.0", + base_path="stabilityai/stable-diffusion-xl-base-1.0", addon_path="UmerHA/Testing-ConrolNetXS-SDXL-canny" ) pipe.enable_sequential_cpu_offload() @@ -343,12 +464,12 @@ def test_canny(self): assert images[0].shape == (768, 512, 3) original_image = images[0, -3:, -3:, -1].flatten() - expected_image = np.array([0.4359, 0.4335, 0.4609, 0.4515, 0.4669, 0.4494, 0.452, 0.4493, 0.4382]) + expected_image = np.array([0.4371, 0.4341, 0.4620, 0.4524, 0.4680, 0.4504, 0.4530, 0.4505, 0.4390]) assert np.allclose(original_image, expected_image, atol=1e-04) def test_depth(self): pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( - components_path="stabilityai/stable-diffusion-xl-base-1.0", + base_path="stabilityai/stable-diffusion-xl-base-1.0", addon_path="UmerHA/Testing-ConrolNetXS-SDXL-depth" ) pipe.enable_sequential_cpu_offload() @@ -365,5 +486,5 @@ def test_depth(self): assert images[0].shape == (512, 512, 3) original_image = images[0, -3:, -3:, -1].flatten() - expected_image = np.array([0.4411, 0.3617, 0.2654, 0.266, 0.3449, 0.3898, 0.3745, 0.353, 0.326]) + expected_image = np.array([0.4082, 0.3879, 0.2781, 0.2655, 0.327 , 0.372 , 0.3762, 0.3444, 0.3122]) assert np.allclose(original_image, expected_image, atol=1e-04) From 1511d7dec1591c7a005cabcc7ae92641c271ae78 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Tue, 30 Jan 2024 15:21:06 +0100 Subject: [PATCH 34/75] Added comments and docstrings --- src/diffusers/models/controlnet_xs.py | 175 +++++++++--------- .../controlnet/pipeline_controlnet.py | 16 +- .../controlnet_xs/pipeline_controlnet_xs.py | 85 ++++++--- .../pipeline_controlnet_xs_sd_xl.py | 82 +++++--- .../controlnet_xs/test_controlnetxs.py | 29 ++- .../controlnet_xs/test_controlnetxs_sdxl.py | 38 ++-- 6 files changed, 238 insertions(+), 187 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index b874bd2c5f80..0f462d6e3ab1 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -89,6 +89,8 @@ class ControlNetXSAddon(ModelMixin, ConfigMixin): This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). + Like `ControlNetXSModel`, `ControlNetXSAddon` is compatible with StableDiffusion and StableDiffusion-XL. + It's default parameters are compatible with StableDiffusion. Parameters: conditioning_channels (`int`, defaults to 3): @@ -96,17 +98,18 @@ class ControlNetXSAddon(ModelMixin, ConfigMixin): conditioning_channel_order (`str`, defaults to `"rgb"`): The channel order of conditional image. Will convert to `rgb` if it's `bgr`. conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): - The tuple of output channel for each block in the `controlnet_cond_embedding` layer. + The tuple of output channels for each block in the `controlnet_cond_embedding` layer. time_embedding_input_dim (`int`, defaults to 320): Dimension of input into time embedding. Needs to be same as in the base model. time_embedding_dim (`int`, defaults to 1280): Dimension of output from time embedding. Needs to be same as in the base model. - learn_time_embedding (`bool`, defaults to `False`): todo - Whether the time embedding should be learned or fixed. - channels_base (`Dict[str, List[Tuple[int]]]`): todo - Base channel configurations for the model's layers. + learn_time_embedding (`bool`, defaults to `False`): + Whether a time embedding should be learned. If yes, `ControlNetXSModel` will combine the time embeddings of the base model and the addon. + If no, `ControlNetXSModel` will use the base model's time embedding. + channels_base (`Dict[str, List[Tuple[int]]]`, defaults to `ControlNetXSAddon.gather_base_subblock_sizes((320,640,1280,1280))`): + Channels of each subblock of the base model. Use `ControlNetXSAddon.gather_base_subblock_sizes` to obtain them. attention_head_dim (`list[int]`, defaults to `[4]`): - The dimension of the attention heads. + The dimension of the attention heads. block_out_channels (`list[int]`, defaults to `[4, 8, 16, 16]`): The tuple of output channels for each block. cross_attention_dim (`int`, defaults to 1024): @@ -119,9 +122,9 @@ class ControlNetXSAddon(ModelMixin, ConfigMixin): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. upcast_attention (`bool`, defaults to `True`): - todo - norm_num_groups (`int`, defaults to 32): - If `None`, normalization and activation layers is skipped in post-processing. # todo: is actually max_norm_num_groups + Whether the attention computation should always be upcasted. + max_norm_num_groups (`int`, defaults to 32): + Maximum number of groups in group normal. The actual number will the the largest divisor of the respective channels, that is <= max_norm_num_groups. """ @staticmethod @@ -187,21 +190,21 @@ def from_unet( conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), ): r""" - Instantiate a [`ControlNetXSAddon`] from [`UNet2DConditionModel`]. + Instantiate a [`ControlNetXSAddon`] from a [`UNet2DConditionModel`]. Parameters: base_model (`UNet2DConditionModel`): - The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it. + The UNet model we want to control. The dimensions of the ControlNetXSAddon will be adapted to it. size_ratio (float, *optional*, defaults to `None`): - When given, block_out_channels is set to a relative fraction of the base model's block_out_channels. + When given, block_out_channels is set to a fraction of the base model's block_out_channels. Either this or `block_out_channels` must be given. - block_out_channels (`Tuple[int]`, *optional*, defaults to `None`): + block_out_channels (`List[int]`, *optional*, defaults to `None`): Down blocks output channels in control model. Either this or `size_ratio` must be given. - num_attention_heads (`Union[int, Tuple[int]]`, *optional*, defaults to `None`): + num_attention_heads (`List[int]`, *optional*, defaults to `None`): The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. learn_time_embedding (`bool`, defaults to `False`): Whether the `ControlNetXSAddon` should learn a time embedding. - conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): + conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`): The tuple of output channel for each block in the `controlnet_cond_embedding` layer. """ @@ -220,7 +223,7 @@ def from_unet( # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. num_attention_heads = base_model.config.attention_head_dim - norm_num_groups = math.gcd(*block_out_channels) + max_norm_num_groups = base_model.config.norm_num_groups time_embedding_input_dim = base_model.time_embedding.linear_1.in_features time_embedding_dim = base_model.time_embedding.linear_1.out_features @@ -235,7 +238,7 @@ def from_unet( sample_size=base_model.config.sample_size, transformer_layers_per_block=base_model.config.transformer_layers_per_block, upcast_attention=base_model.config.upcast_attention, - norm_num_groups=norm_num_groups, + max_norm_num_groups=max_norm_num_groups, conditioning_embedding_out_channels=conditioning_embedding_out_channels, time_embedding_input_dim=time_embedding_input_dim, time_embedding_dim=time_embedding_dim, @@ -250,19 +253,20 @@ def __init__( time_embedding_input_dim: Optional[int] = 320, time_embedding_dim: Optional[int] = 1280, learn_time_embedding: bool = False, - channels_base: Dict[str, List[Tuple[int]]] = { - "down - out": [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280], - "mid - out": 1280, - "up - in": [1280, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320], - }, + channels_base: Dict[str, List[Tuple[int]]] = gather_base_subblock_sizes((320, 640, 1280, 1280)), attention_head_dim: Union[int, Tuple[int]] = 4, - block_out_channels : Tuple[int] = (4, 8, 16, 16), - cross_attention_dim: int =1024, - down_block_types: Tuple[str]=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), - sample_size: Optional[int]=96, # todo understand + block_out_channels: Tuple[int] = (4, 8, 16, 16), + cross_attention_dim: int = 1024, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + sample_size: Optional[int] = 96, transformer_layers_per_block: Union[int, Tuple[int]] = 1, upcast_attention: bool = True, - norm_num_groups: int = 32, # todo: rename max_norm_num_groups? + max_norm_num_groups: int = 32, ): super().__init__() @@ -337,7 +341,7 @@ def __init__( num_attention_heads=num_attention_heads[i], cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - norm_num_groups=norm_num_groups, + max_norm_num_groups=max_norm_num_groups, ) ) subblock_counter += 1 @@ -351,7 +355,7 @@ def __init__( num_attention_heads=num_attention_heads[i], cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - norm_num_groups=norm_num_groups, + max_norm_num_groups=max_norm_num_groups, ) ) subblock_counter += 1 @@ -376,14 +380,13 @@ def __init__( resnet_eps=1e-05, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[-1], - resnet_groups=find_largest_factor(mid_in_channels, norm_num_groups), - resnet_groups_out=find_largest_factor(mid_out_channels, norm_num_groups), + resnet_groups=find_largest_factor(mid_in_channels, max_norm_num_groups), + resnet_groups_out=find_largest_factor(mid_out_channels, max_norm_num_groups), use_linear_projection=True, upcast_attention=upcast_attention, ) # 3 - Gather Channel Sizes - conditioning_embedding_out_channels channels_ctrl = { "down - out": [self.conv_in.out_channels] + [s.out_channels for s in self.down_subblocks], "mid - out": self.down_subblocks[-1].out_channels, @@ -397,14 +400,11 @@ def __init__( self.up_zero_convs_c2b = nn.ModuleList([]) # 4.1 - Connections from base encoder to ctrl encoder - # todo - better comment - # Information is passed from base to ctrl _before_ each subblock. We therefore use the 'in' channels. - # As the information is concatted in ctrl, we don't need to change channel sizes. So channels in = channels out. - for c in channels_base["down - out"]: # change down - in to down - out + # As the information is concatted to ctrl, the channels sizes don't change. + for c in channels_base["down - out"]: self.down_zero_convs_b2c.append(self._make_zero_conv(c, c)) # 4.2 - Connections from ctrl encoder to base encoder - # Information is passed from ctrl to base _after_ each subblock. We therefore use the 'out' channels. # As the information is added to base, the out-channels need to match base. for ch_base, ch_ctrl in zip(channels_base["down - out"], channels_ctrl["down - out"]): self.down_zero_convs_c2b.append(self._make_zero_conv(ch_ctrl, ch_base)) @@ -440,39 +440,29 @@ class ControlNetXSModel(nn.Module): This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). - Most of parameters for this model are passed into the [`UNet2DConditionModel`] it creates. Check the documentation - of [`UNet2DConditionModel`] for them. + `ControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. + It's default parameters are compatible with StableDiffusion. Parameters: - conditioning_channels (`int`, defaults to 3): - Number of channels of conditioning input (e.g. an image) - controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): - The channel order of conditional image. Will convert to `rgb` if it's `bgr`. - conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): - The tuple of output channel for each block in the `controlnet_cond_embedding` layer. - time_embedding_input_dim (`int`, defaults to 320): - Dimension of input into time embedding. Needs to be same as in the base model. - time_embedding_dim (`int`, defaults to 1280): - Dimension of output from time embedding. Needs to be same as in the base model. - learn_embedding (`bool`, defaults to `False`): - Whether to use time embedding of the control model. If yes, the time embedding is a linear interpolation of - the time embeddings of the control and base model with interpolation parameter `time_embedding_mix**3`. + base_model (`UNet2DConditionModel`): + The base UNet to control. + ctrl_addon (`ControlNetXSAddon`): + The control addon. time_embedding_mix (`float`, defaults to 1.0): - Linear interpolation parameter used if `learn_embedding` is `True`. A value of 1.0 means only the - control model's time embedding will be used. A value of 0.0 means only the base model's time embedding will be used. - channels_base (`Dict[str, List[Tuple[int]]]`): - Channel sizes of each subblock of base model. Use `gather_subblock_sizes` on your base model to compute it. + If 0, then only the base model's time embedding is be used. + If 1, then only the control model's time embedding is be used. + Otherwise, both are combined. """ @classmethod - def init_original(cls, base_model: UNet2DConditionModel, is_sdxl=True): + def init_original(cls, base_model: UNet2DConditionModel, is_sdxl=False): """ - Create a ControlNetXS model with the same parameters as in the original paper (https://github.com/vislearn/ControlNet-XS). + Create a `ControlNetXSModel` model with the same parameters as in the original paper (https://github.com/vislearn/ControlNet-XS). Parameters: base_model (`UNet2DConditionModel`): Base UNet model. Needs to be either StableDiffusion or StableDiffusion-XL. - is_sdxl (`bool`, defaults to `True`): + is_sdxl (`bool`, defaults to `False`): Whether passed `base_model` is a StableDiffusion-XL model. """ @@ -480,7 +470,7 @@ def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_ """ Currently, diffusers can only set the dimension of attention heads (see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why). The original ControlNet-XS model, however, define the number of attention heads. - That's why compute the dimensions needed to get the correct number of attention heads. + That's why we compute the dimensions needed to get the correct number of attention heads. """ block_out_channels = [int(size_ratio * c) for c in base_model.config.block_out_channels] dim_attn_heads = [math.ceil(c / num_attn_heads) for c in block_out_channels] @@ -503,17 +493,24 @@ def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_ num_attention_heads=get_dim_attn_heads(base_model, 0.0125, 8), ) - return cls(base_model=base_model, ctrl_model=controlnet_addon, time_embedding_mix=time_embedding_mix) + return cls(base_model=base_model, ctrl_addon=controlnet_addon, time_embedding_mix=time_embedding_mix) def __init__( self, base_model: UNet2DConditionModel, - ctrl_model: ControlNetXSAddon, + ctrl_addon: ControlNetXSAddon, time_embedding_mix: float = 1.0, ): super().__init__() - self.ctrl_model = ctrl_model + if time_embedding_mix < 0 or time_embedding_mix > 1: + raise ValueError("`time_embedding_mix` needs to be between 0 and 1.") + if time_embedding_mix < 1 and not ctrl_addon.config.learn_time_embedding: + raise ValueError( + "To use `time_embedding_mix` < 1, initialize `ctrl_addon` with `learn_time_embedding = True`" + ) + + self.ctrl_addon = ctrl_addon self.base_model = base_model self.time_embedding_mix = time_embedding_mix @@ -571,7 +568,7 @@ def dtype(self) -> torch.dtype: @torch.no_grad() def _check_if_vae_compatible(self, vae: AutoencoderKL): - condition_downscale_factor = 2 ** (len(self.ctrl_model.config.conditioning_embedding_out_channels) - 1) + condition_downscale_factor = 2 ** (len(self.ctrl_addon.config.conditioning_embedding_out_channels) - 1) vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1) compatible = condition_downscale_factor == vae_downscale_factor return compatible, condition_downscale_factor, vae_downscale_factor @@ -592,11 +589,9 @@ def forward( do_control: bool = True, ) -> Union[ControlNetXSOutput, Tuple]: """ - The [`ControlNetModel`] forward method. + The [`ControlNetXSModel`] forward method. Args: - base_model (`UNet2DConditionModel`): - The base unet model we want to control. sample (`torch.FloatTensor`): The noisy input tensor. timestep (`Union[torch.Tensor, float, int]`): @@ -617,12 +612,14 @@ def forward( An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. - added_cond_kwargs (`dict`): - Additional conditions for the Stable Diffusion XL UNet. cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. return_dict (`bool`, defaults to `True`): Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + do_control (`bool`, defaults to `True`): + If `False`, the input is run only through the base model. Returns: [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`: @@ -640,11 +637,11 @@ def forward( attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, - return_dict=return_dict + return_dict=return_dict, ) # check channel order - if self.ctrl_model.config.conditioning_channel_order == "bgr": + if self.ctrl_addon.config.conditioning_channel_order == "bgr": controlnet_cond = torch.flip(controlnet_cond, dims=[1]) # prepare attention_mask @@ -676,8 +673,8 @@ def forward( # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) - if self.ctrl_model.config.learn_time_embedding: - ctrl_temb = self.ctrl_model.time_embedding(t_emb, timestep_cond) + if self.ctrl_addon.config.learn_time_embedding: + ctrl_temb = self.ctrl_addon.time_embedding(t_emb, timestep_cond) base_temb = self.base_model.time_embedding(t_emb, timestep_cond) interpolation_param = self.time_embedding_mix**0.3 @@ -728,7 +725,7 @@ def forward( cemb = encoder_hidden_states # Preparation - guided_hint = self.ctrl_model.controlnet_cond_embedding(controlnet_cond) + guided_hint = self.ctrl_addon.controlnet_cond_embedding(controlnet_cond) h_ctrl = h_base = sample hs_base, hs_ctrl = [], [] @@ -736,21 +733,21 @@ def forward( # Cross Control # Let's first define variables to shorten notation base_down_subblocks = self.base_down_subblocks - ctrl_down_subblocks = self.ctrl_model.down_subblocks + ctrl_down_subblocks = self.ctrl_addon.down_subblocks - down_zero_convs_b2c = self.ctrl_model.down_zero_convs_b2c - down_zero_convs_c2b = self.ctrl_model.down_zero_convs_c2b - mid_zero_convs_c2b = self.ctrl_model.mid_zero_convs_c2b - up_zero_convs_c2b = self.ctrl_model.up_zero_convs_c2b + down_zero_convs_b2c = self.ctrl_addon.down_zero_convs_b2c + down_zero_convs_c2b = self.ctrl_addon.down_zero_convs_c2b + mid_zero_convs_c2b = self.ctrl_addon.mid_zero_convs_c2b + up_zero_convs_c2b = self.ctrl_addon.up_zero_convs_c2b # 1 - conv in & down - # The base -> ctrl connections are "delayed" by 1 subblock, because we want to "wait" to ensure the new information from the last ctrl -> base connection is also considered + # The base -> ctrl connections are "delayed" by 1 subblock, because we want to "wait" to ensure the new information from the last ctrl -> base connection is also considered. # Therefore, the connections iterate over: # ctrl -> base: conv_in | subblock 1 | ... | subblock n # base -> ctrl: | subblock 1 | ... | subblock n | mid block h_base = self.base_model.conv_in(h_base) - h_ctrl = self.ctrl_model.conv_in(h_ctrl) + h_ctrl = self.ctrl_addon.conv_in(h_ctrl) if guided_hint is not None: h_ctrl += guided_hint h_base = h_base + down_zero_convs_c2b[0](h_ctrl) * conditioning_scale # add ctrl -> base @@ -779,8 +776,12 @@ def forward( h_ctrl = torch.cat([h_ctrl, down_zero_convs_b2c[-1](h_base)], dim=1) # concat base -> ctrl # 2 - mid - h_base = self.base_model.mid_block(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # apply base subblock - h_ctrl = self.ctrl_model.mid_block(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # apply ctrl subblock + h_base = self.base_model.mid_block( + h_base, temb, cemb, attention_mask, cross_attention_kwargs + ) # apply base subblock + h_ctrl = self.ctrl_addon.mid_block( + h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs + ) # apply ctrl subblock h_base = h_base + mid_zero_convs_c2b(h_ctrl) * conditioning_scale # add ctrl -> base # 3 - up @@ -825,7 +826,7 @@ def __init__( in_channels: Optional[int] = None, out_channels: Optional[int] = None, temb_channels: Optional[int] = None, - norm_num_groups: Optional[int] = 32, + max_norm_num_groups: Optional[int] = 32, has_crossattn=False, transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, num_attention_heads: Optional[int] = 1, @@ -846,8 +847,8 @@ def __init__( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, - groups=find_largest_factor(in_channels, max_factor=norm_num_groups), - groups_out=find_largest_factor(out_channels, max_factor=norm_num_groups), + groups=find_largest_factor(in_channels, max_factor=max_norm_num_groups), + groups_out=find_largest_factor(out_channels, max_factor=max_norm_num_groups), eps=1e-5, ) @@ -860,7 +861,7 @@ def __init__( cross_attention_dim=cross_attention_dim, use_linear_projection=True, upcast_attention=upcast_attention, - norm_num_groups=find_largest_factor(out_channels, max_factor=norm_num_groups), + norm_num_groups=find_largest_factor(out_channels, max_factor=max_norm_num_groups), ) else: self.attention = None diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index bb6a9a0ba58a..ddc6ced2f80b 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -1171,7 +1171,7 @@ def __call__( is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): + for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: @@ -1198,7 +1198,10 @@ def __call__( controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] - print(f'Denoising step {i} > Right before controlnet application : Device type of controlnet >> ',self.controlnet.device.type) + print( + f"Denoising step {i} > Right before controlnet application : Device type of controlnet >> ", + self.controlnet.device.type, + ) down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, @@ -1208,7 +1211,10 @@ def __call__( guess_mode=guess_mode, return_dict=False, ) - print(f'Denoising step {i} > Right after controlnet application : Device type of controlnet >> ',self.controlnet.device.type) + print( + f"Denoising step {i} > Right after controlnet application : Device type of controlnet >> ", + self.controlnet.device.type, + ) if guess_mode and self.do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. @@ -1242,7 +1248,7 @@ def __call__( callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - print('btw, calling callback_on_step_end') + print("btw, calling callback_on_step_end") callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) @@ -1253,7 +1259,7 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: - print('btw, calling callback') + print("btw, calling callback") step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 080d5d504d26..b44ff740ddf9 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -22,8 +22,8 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ImageProjection, ControlNetXSAddon, ControlNetXSModel +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetXSAddon, ControlNetXSModel, ImageProjection from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -66,12 +66,11 @@ >>> # initialize the models and pipeline >>> controlnet_conditioning_scale = 0.5 - >>> controlnet = ControlNetXSModel.from_pretrained( - ... "UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 - ... ) + >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16 - ... ) + >>> base_path="stabilityai/stable-diffusion-2-1", base_kwargs=dict(torch_dtype=torch.float16), + >>> addon_path="UmerHA/Testing-ConrolNetXS-SD2.1-canny", addon_kwargs=dict(torch_dtype=torch.float16), + >>> ) >>> pipe.enable_model_cpu_offload() >>> # get canny image @@ -89,10 +88,9 @@ class StableDiffusionControlNetXSPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin ): r""" - # todo Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods @@ -100,6 +98,9 @@ class StableDiffusionControlNetXSPipeline( The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files Args: @@ -109,10 +110,8 @@ class StableDiffusionControlNetXSPipeline( Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. - unet ([`UNet2DConditionModel`]): - A `UNet2DConditionModel` to denoise the encoded image latents. - controlnet_addon ([`ControlNetXSAddon`]): - Provides additional conditioning to the `unet` during the denoising process. + controlnet ([`ControlNetXSModel`]): + A model containing a base UNet and a control addon. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. @@ -189,26 +188,50 @@ def __init__( @classmethod def from_pretrained(cls, base_path, addon_path, base_kwargs={}, addon_kwargs={}): """ - todo: docstring + Instantiates pipeline from a `StableDiffusionPipeline` and a `ControlNetXSAddon`. + + Arguments: + base_path (`str` or `os.PathLike`): + Directory to load underlying `StableDiffusionPipeline` from. + addon_path (`str` or `os.PathLike`): + Directory to load underlying `ControlNetXSAddon` model from. + base_kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~StableDiffusionPipeline.from_pretrained`] method. + addon_kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~ControlNetXSAddon.from_pretrained`] method. """ + components = StableDiffusionPipeline.from_pretrained(base_path, **base_kwargs).components controlnet_addon = ControlNetXSAddon.from_pretrained(addon_path, **addon_kwargs) - # todo: what if StableDiffusionPipeline has more params than StableDiffusionControlNetXSPipeline - # eg if some features are not implemented in cnxs yet? - unet = components["unet"] - components = {k:v for k,v in components.items() if k != "unet"} + components = {k: v for k, v in components.items() if k != "unet"} controlnet = ControlNetXSModel(unet, controlnet_addon) return StableDiffusionControlNetXSPipeline(controlnet=controlnet, **components) def save_pretrained(self, base_path, addon_path, base_kwargs={}, addon_kwargs={}): - """todo docs""" - components = {k:v for k,v in self.components.items() if k!="controlnet"} + """ + + Separately save the underlying `StableDiffusionPipeline` and the `ControlNetXSAddon` so the pipeline is easily reloaded using the + [`~StableDiffusionControlNetXSPipeline.from_pretrained`] class method. + + Arguments: + base_path (`str` or `os.PathLike`): + Directory to save underlying `StableDiffusionPipeline` to. Will be created if it doesn't exist. + addon_path (`str` or `os.PathLike`): + Directory to save underlying `ControlNetXSAddon` model to. Will be created if it doesn't exist. + base_kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~StableDiffusionPipeline.save_pretrained`] method. + addon_kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~ControlNetXSAddon.save_pretrained`] method. + + """ + + components = {k: v for k, v in self.components.items() if k != "controlnet"} components["unet"] = self.components["controlnet"].base_model - controlnet_addon = self.components["controlnet"].ctrl_model + controlnet_addon = self.components["controlnet"].ctrl_addon StableDiffusionPipeline(**components).save_pretrained(base_path, **base_kwargs) controlnet_addon.save_pretrained(addon_path, **addon_kwargs) @@ -825,17 +848,13 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). @@ -850,7 +869,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. Examples: Returns: @@ -927,7 +954,7 @@ def __call__( lora_scale=text_encoder_lora_scale, clip_skip=clip_skip, ) - + # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index 2789cbe77fcb..e54511fa782a 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -21,7 +21,7 @@ import torch.nn.functional as F from transformers import ( CLIPImageProcessor, - CLIPTextModel, + CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, @@ -31,7 +31,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetXSAddon, ControlNetXSModel, UNet2DConditionModel +from ...models import AutoencoderKL, ControlNetXSAddon, ControlNetXSModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -40,7 +40,14 @@ ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers -from ...utils import USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion_xl import StableDiffusionXLPipeline @@ -76,11 +83,12 @@ >>> # initialize the models and pipeline >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization - >>> controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16) >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) - >>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16 - ... ) + >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( + >>> base_path="stabilityai/stable-diffusion-xl-base-1.0", base_kwargs=dict(vae=vae, torch_dtype=torch.float16), + >>> addon_path="UmerHA/Testing-ConrolNetXS-SDXL-canny", addon_kwargs=dict(torch_dtype=torch.float16), + >>> ) + >>> pipe.enable_model_cpu_offload() >>> # get canny image @@ -187,7 +195,7 @@ def __init__( scheduler=scheduler, feature_extractor=feature_extractor, image_encoder=image_encoder, - ) + ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( @@ -205,26 +213,50 @@ def __init__( @classmethod def from_pretrained(cls, base_path, addon_path, base_kwargs={}, addon_kwargs={}): """ - todo: docstring + Instantiates pipeline from a `StableDiffusionXLPipeline` and a `ControlNetXSAddon`. + + Arguments: + base_path (`str` or `os.PathLike`): + Directory to load underlying `StableDiffusionXLPipeline` from. + addon_path (`str` or `os.PathLike`): + Directory to load underlying `ControlNetXSAddon` model from. + base_kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~StableDiffusionXLPipeline.from_pretrained`] method. + addon_kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~ControlNetXSAddon.from_pretrained`] method. """ + components = StableDiffusionXLPipeline.from_pretrained(base_path, **base_kwargs).components controlnet_addon = ControlNetXSAddon.from_pretrained(addon_path, **addon_kwargs) - # todo: what if StableDiffusionXLPipeline has more params than StableDiffusionControlNetXSPipeline - # eg if some features are not implemented in cnxs yet? - unet = components["unet"] - components = {k:v for k,v in components.items() if k != "unet"} + components = {k: v for k, v in components.items() if k != "unet"} controlnet = ControlNetXSModel(unet, controlnet_addon) return StableDiffusionXLControlNetXSPipeline(controlnet=controlnet, **components) def save_pretrained(self, base_path, addon_path, base_kwargs={}, addon_kwargs={}): - """todo docs""" - components = {k:v for k,v in self.components.items() if k!="controlnet"} + """ + + Separately save the underlying `StableDiffusionXLPipeline` and the `ControlNetXSAddon` so the pipeline is easily reloaded using the + [`~StableDiffusionControlNetXSPipeline.from_pretrained`] class method. + + Arguments: + base_path (`str` or `os.PathLike`): + Directory to save underlying `StableDiffusionXLPipeline` to. Will be created if it doesn't exist. + addon_path (`str` or `os.PathLike`): + Directory to save underlying `ControlNetXSAddon` model to. Will be created if it doesn't exist. + base_kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~StableDiffusionXLPipeline.save_pretrained`] method. + addon_kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~ControlNetXSAddon.save_pretrained`] method. + + """ + + components = {k: v for k, v in self.components.items() if k != "controlnet"} components["unet"] = self.components["controlnet"].base_model - controlnet_addon = self.components["controlnet"].ctrl_model + controlnet_addon = self.components["controlnet"].ctrl_addon StableDiffusionXLPipeline(**components).save_pretrained(base_path, **base_kwargs) controlnet_addon.save_pretrained(addon_path, **addon_kwargs) @@ -262,7 +294,6 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() - # todo: check if copy def encode_prompt( self, prompt: str, @@ -321,6 +352,8 @@ def encode_prompt( Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ + # Note: this is almost an exact copy of `StableDiffusionXLPipeline.encode_prompt` except that `sefl.controlnet` is used instead of `self.unet` + device = device or self._execution_device # set lora scale so that monkey patched LoRA @@ -933,12 +966,6 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). @@ -981,6 +1008,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. Examples: diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index 98b4a1bf8b40..f8a766068c9b 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -14,9 +14,9 @@ # limitations under the License. import gc +import tempfile import traceback import unittest -import tempfile import numpy as np import torch @@ -57,7 +57,7 @@ PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, - to_np + to_np, ) @@ -140,7 +140,7 @@ def get_dummy_components(self, time_cond_proj_dim=None): learn_time_embedding=True, conditioning_embedding_out_channels=(16, 32), ) - controlnet = ControlNetXSModel(base_model=unet, ctrl_model=controlnet_addon) + controlnet = ControlNetXSModel(base_model=unet, ctrl_addon=controlnet_addon) torch.manual_seed(0) scheduler = DDIMScheduler( beta_start=0.00085, @@ -263,14 +263,11 @@ def test_save_load_local(self, expected_max_difference=5e-4): pipe.save_pretrained( base_path=tmpdir_components, addon_path=tmpdir_addon, - base_kwargs=dict(safe_serialization=False), - addon_kwargs=dict(safe_serialization=False), + base_kwargs={"safe_serialization": False}, + addon_kwargs={"safe_serialization": False}, ) - pipe_loaded = self.pipeline_class.from_pretrained( - base_path=tmpdir_components, - addon_path=tmpdir_addon - ) + pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, addon_path=tmpdir_addon) for component in pipe_loaded.components.values(): if hasattr(component, "set_default_attn_processor"): @@ -304,18 +301,14 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): with tempfile.TemporaryDirectory() as tmpdir_components: with tempfile.TemporaryDirectory() as tmpdir_addon: - pipe.save_pretrained( base_path=tmpdir_components, addon_path=tmpdir_addon, - base_kwargs=dict(safe_serialization=False), - addon_kwargs=dict(safe_serialization=False), + base_kwargs={"safe_serialization": False}, + addon_kwargs={"safe_serialization": False}, ) - pipe_loaded = self.pipeline_class.from_pretrained( - base_path=tmpdir_components, - addon_path=tmpdir_addon - ) + pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, addon_path=tmpdir_addon) for component in pipe_loaded.components.values(): if hasattr(component, "set_default_attn_processor"): @@ -363,7 +356,7 @@ def test_canny(self): image = output.images[0] assert image.shape == (768, 512, 3) - + original_image = image[-3:, -3:, -1].flatten() expected_image = np.array([0.1462, 0.1518, 0.1583, 0.1332, 0.1655, 0.1629, 0.1646, 0.1595, 0.1762]) assert np.allclose(original_image, expected_image, atol=1e-04) @@ -389,7 +382,7 @@ def test_depth(self): assert image.shape == (512, 512, 3) original_image = image[-3:, -3:, -1].flatten() - expected_image = np.array([0.1504, 0.1448, 0.1742, 0.155 , 0.1553, 0.1833, 0.1694, 0.1833, 0.2354]) + expected_image = np.array([0.1504, 0.1448, 0.1742, 0.155, 0.1553, 0.1833, 0.1694, 0.1833, 0.2354]) assert np.allclose(original_image, expected_image, atol=1e-04) @require_python39_or_higher diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index c854561d78d1..8578148ac7a7 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -14,8 +14,8 @@ # limitations under the License. import gc -import unittest import tempfile +import unittest import numpy as np import torch @@ -46,7 +46,7 @@ PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, - to_np + to_np, ) @@ -94,7 +94,7 @@ def get_dummy_components(self): learn_time_embedding=True, conditioning_embedding_out_channels=(16, 32), ) - controlnet = ControlNetXSModel(base_model=unet, ctrl_model=controlnet_addon) + controlnet = ControlNetXSModel(base_model=unet, ctrl_addon=controlnet_addon) torch.manual_seed(0) scheduler = EulerDiscreteScheduler( beta_start=0.00085, @@ -187,10 +187,6 @@ def test_xformers_attention_forwardGenerator_pass(self): def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) - # copied from test_controlnet_sdxl.py - def test_save_load_optional_components(self): - self._test_save_load_optional_components() - # copied from test_controlnet_sdxl.py @require_torch_gpu def test_stable_diffusion_xl_offloads(self): @@ -336,14 +332,11 @@ def test_save_load_local(self, expected_max_difference=5e-4): pipe.save_pretrained( base_path=tmpdir_components, addon_path=tmpdir_addon, - base_kwargs=dict(safe_serialization=False), - addon_kwargs=dict(safe_serialization=False), + base_kwargs={"safe_serialization": False}, + addon_kwargs={"safe_serialization": False}, ) - pipe_loaded = self.pipeline_class.from_pretrained( - base_path=tmpdir_components, - addon_path=tmpdir_addon - ) + pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, addon_path=tmpdir_addon) for component in pipe_loaded.components.values(): if hasattr(component, "set_default_attn_processor"): @@ -398,18 +391,14 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): with tempfile.TemporaryDirectory() as tmpdir_components: with tempfile.TemporaryDirectory() as tmpdir_addon: - pipe.save_pretrained( base_path=tmpdir_components, addon_path=tmpdir_addon, - base_kwargs=dict(safe_serialization=False), - addon_kwargs=dict(safe_serialization=False), + base_kwargs={"safe_serialization": False}, + addon_kwargs={"safe_serialization": False}, ) - pipe_loaded = self.pipeline_class.from_pretrained( - base_path=tmpdir_components, - addon_path=tmpdir_addon - ) + pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, addon_path=tmpdir_addon) for component in pipe_loaded.components.values(): if hasattr(component, "set_default_attn_processor"): @@ -437,6 +426,7 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() self.assertLess(max_diff, expected_max_difference) + @slow @require_torch_gpu class StableDiffusionXLControlNetXSPipelineSlowTests(unittest.TestCase): @@ -447,8 +437,7 @@ def tearDown(self): def test_canny(self): pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( - base_path="stabilityai/stable-diffusion-xl-base-1.0", - addon_path="UmerHA/Testing-ConrolNetXS-SDXL-canny" + base_path="stabilityai/stable-diffusion-xl-base-1.0", addon_path="UmerHA/Testing-ConrolNetXS-SDXL-canny" ) pipe.enable_sequential_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -469,8 +458,7 @@ def test_canny(self): def test_depth(self): pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( - base_path="stabilityai/stable-diffusion-xl-base-1.0", - addon_path="UmerHA/Testing-ConrolNetXS-SDXL-depth" + base_path="stabilityai/stable-diffusion-xl-base-1.0", addon_path="UmerHA/Testing-ConrolNetXS-SDXL-depth" ) pipe.enable_sequential_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -486,5 +474,5 @@ def test_depth(self): assert images[0].shape == (512, 512, 3) original_image = images[0, -3:, -3:, -1].flatten() - expected_image = np.array([0.4082, 0.3879, 0.2781, 0.2655, 0.327 , 0.372 , 0.3762, 0.3444, 0.3122]) + expected_image = np.array([0.4082, 0.3879, 0.2781, 0.2655, 0.327, 0.372, 0.3762, 0.3444, 0.3122]) assert np.allclose(original_image, expected_image, atol=1e-04) From 6041bb188d000854b63a2e7dd82299448e04a007 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Tue, 30 Jan 2024 15:31:17 +0100 Subject: [PATCH 35/75] Fixed copies --- .../controlnet_xs/pipeline_controlnet_xs.py | 10 +++---- .../pipeline_controlnet_xs_sd_xl.py | 10 +++---- .../versatile_diffusion/modeling_text_unet.py | 26 ++++++++++------ src/diffusers/utils/dummy_pt_objects.py | 30 +++++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 30 +++++++++++++++++++ 5 files changed, 87 insertions(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index b44ff740ddf9..800fefc9386a 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -749,28 +749,28 @@ def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale def guidance_scale(self): return self._guidance_scale - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip def clip_skip(self): return self._clip_skip - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance def do_classifier_free_guidance(self): return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs def cross_attention_kwargs(self): return self._cross_attention_kwargs - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps def num_timesteps(self): return self._num_timesteps diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index e54511fa782a..fba3d0fc03d1 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -836,28 +836,28 @@ def disable_freeu(self): # todo: check if works self.controlnet.disable_freeu() - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale def guidance_scale(self): return self._guidance_scale - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip def clip_skip(self): return self._clip_skip - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance def do_classifier_free_guidance(self): return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs def cross_attention_kwargs(self): return self._cross_attention_kwargs - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps def num_timesteps(self): return self._num_timesteps diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 6f95112c3d50..d86a5a018425 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -2191,6 +2191,7 @@ def __init__( self, in_channels: int, temb_channels: int, + out_channels: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: Union[int, Tuple[int]] = 1, @@ -2198,6 +2199,7 @@ def __init__( resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, + resnet_groups_out: Optional[int] = None, resnet_pre_norm: bool = True, num_attention_heads: int = 1, output_scale_factor: float = 1.0, @@ -2209,9 +2211,14 @@ def __init__( ): super().__init__() + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.has_cross_attention = True self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + resnet_groups_out = resnet_groups_out or resnet_groups # support for variable transformer layers per block if isinstance(transformer_layers_per_block, int): @@ -2221,10 +2228,11 @@ def __init__( resnets = [ ResnetBlockFlat( in_channels=in_channels, - out_channels=in_channels, + out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, + groups_out=resnet_groups_out, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, @@ -2239,11 +2247,11 @@ def __init__( attentions.append( Transformer2DModel( num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, + out_channels // num_attention_heads, + in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, + norm_num_groups=resnet_groups_out, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, @@ -2253,8 +2261,8 @@ def __init__( attentions.append( DualTransformer2DModel( num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, + out_channels // num_attention_heads, + in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, @@ -2262,11 +2270,11 @@ def __init__( ) resnets.append( ResnetBlockFlat( - in_channels=in_channels, - out_channels=in_channels, + in_channels=out_channels, + out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, - groups=resnet_groups, + groups=resnet_groups_out, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index d306a3575b1f..0a7e5f3bd787 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -92,6 +92,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ControlNetXSAddon(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ControlNetXSModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class Kandinsky3UNet(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 2eb9599658d9..ae6c6c916065 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -782,6 +782,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionControlNetXSPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -1127,6 +1142,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 1cdcb27ab5b626a26664ba2e5c863c5554879349 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Tue, 30 Jan 2024 16:45:40 +0100 Subject: [PATCH 36/75] Added docs ; Updates slow tests --- docs/source/en/api/pipelines/controlnetxs.md | 39 ++++++++++++++++ .../en/api/pipelines/controlnetxs_sdxl.md | 45 +++++++++++++++++++ .../controlnet_xs/pipeline_controlnet_xs.py | 2 +- .../pipeline_controlnet_xs_sd_xl.py | 2 +- .../controlnet_xs/test_controlnetxs.py | 14 +++--- .../controlnet_xs/test_controlnetxs_sdxl.py | 2 +- 6 files changed, 94 insertions(+), 10 deletions(-) create mode 100644 docs/source/en/api/pipelines/controlnetxs.md create mode 100644 docs/source/en/api/pipelines/controlnetxs_sdxl.md diff --git a/docs/source/en/api/pipelines/controlnetxs.md b/docs/source/en/api/pipelines/controlnetxs.md new file mode 100644 index 000000000000..2d4ae7b8ce46 --- /dev/null +++ b/docs/source/en/api/pipelines/controlnetxs.md @@ -0,0 +1,39 @@ + + +# ControlNet-XS + +ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results. + +Like the original ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process. + +ControlNet-XS generates images with comparable quality to a regular ControlNet, but it is 20-25% faster ([see benchmark](https://github.com/UmerHA/controlnet-xs-benchmark/blob/main/Speed%20Benchmark.ipynb) with StableDiffusion-XL) and uses ~45% less memory. + +Here's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/): + +*With increasing computing capabilities, current model architectures appear to follow the trend of simply upscaling all components without validating the necessity for doing so. In this project we investigate the size and architectural design of ControlNet [Zhang et al., 2023] for controlling the image generation process with stable diffusion-based models. We show that a new architecture with as little as 1% of the parameters of the base model achieves state-of-the art results, considerably better than ControlNet in terms of FID score. Hence we call it ControlNet-XS. We provide the code for controlling StableDiffusion-XL [Podell et al., 2023] (Model B, 48M Parameters) and StableDiffusion 2.1 [Rombach et al. 2022] (Model B, 14M Parameters), all under openrail license.* + +This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️ + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. + + + +## StableDiffusionControlNetXSPipeline +[[autodoc]] StableDiffusionControlNetXSPipeline + - all + - __call__ + +## StableDiffusionPipelineOutput +[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput diff --git a/docs/source/en/api/pipelines/controlnetxs_sdxl.md b/docs/source/en/api/pipelines/controlnetxs_sdxl.md new file mode 100644 index 000000000000..31075c0ef96a --- /dev/null +++ b/docs/source/en/api/pipelines/controlnetxs_sdxl.md @@ -0,0 +1,45 @@ + + +# ControlNet-XS with Stable Diffusion XL + +ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results. + +Like the original ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process. + +ControlNet-XS generates images with comparable quality to a regular ControlNet, but it is 20-25% faster ([see benchmark](https://github.com/UmerHA/controlnet-xs-benchmark/blob/main/Speed%20Benchmark.ipynb)) and uses ~45% less memory. + +Here's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/): + +*With increasing computing capabilities, current model architectures appear to follow the trend of simply upscaling all components without validating the necessity for doing so. In this project we investigate the size and architectural design of ControlNet [Zhang et al., 2023] for controlling the image generation process with stable diffusion-based models. We show that a new architecture with as little as 1% of the parameters of the base model achieves state-of-the art results, considerably better than ControlNet in terms of FID score. Hence we call it ControlNet-XS. We provide the code for controlling StableDiffusion-XL [Podell et al., 2023] (Model B, 48M Parameters) and StableDiffusion 2.1 [Rombach et al. 2022] (Model B, 14M Parameters), all under openrail license.* + +This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️ + + + +🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve! + + + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. + + + +## StableDiffusionXLControlNetXSPipeline +[[autodoc]] StableDiffusionXLControlNetXSPipeline + - all + - __call__ + +## StableDiffusionPipelineOutput +[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 800fefc9386a..aee2a5123a82 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -68,7 +68,7 @@ >>> controlnet_conditioning_scale = 0.5 >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - >>> base_path="stabilityai/stable-diffusion-2-1", base_kwargs=dict(torch_dtype=torch.float16), + >>> base_path="stabilityai/stable-diffusion-2-1-base", base_kwargs=dict(torch_dtype=torch.float16), >>> addon_path="UmerHA/Testing-ConrolNetXS-SD2.1-canny", addon_kwargs=dict(torch_dtype=torch.float16), >>> ) >>> pipe.enable_model_cpu_offload() diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index fba3d0fc03d1..a204be74ed3e 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -776,7 +776,7 @@ def _get_add_time_ids( add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.controlnet.base_model.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + self.controlnet.base_model.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.controlnet.base_model.add_embedding.linear_1.in_features diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index f8a766068c9b..49330dceb64e 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -70,10 +70,10 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout): try: _ = in_queue.get(timeout=timeout) - controlnet_addon = ControlNetXSAddon.from_pretrained("todo umer") - pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - "stabilityai/stable-diffusion-2-1", safety_checker=None, controlnet_addon=controlnet_addon + base_path="stabilityai/stable-diffusion-2-1-base", + base_kwargs={"safety_checker": None}, + addon_path="UmerHA/Testing-ConrolNetXS-SD2.1-canny", ) pipe.to("cuda") pipe.set_progress_bar_config(disable=None) @@ -339,7 +339,7 @@ def tearDown(self): def test_canny(self): pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - base_path="stabilityai/stable-diffusion-2-1", + base_path="stabilityai/stable-diffusion-2-1-base", addon_path="UmerHA/Testing-ConrolNetXS-SD2.1-canny", ) pipe.enable_model_cpu_offload() @@ -358,12 +358,12 @@ def test_canny(self): assert image.shape == (768, 512, 3) original_image = image[-3:, -3:, -1].flatten() - expected_image = np.array([0.1462, 0.1518, 0.1583, 0.1332, 0.1655, 0.1629, 0.1646, 0.1595, 0.1762]) + expected_image = np.array([0.1276, 0.1405, 0.1474, 0.1188, 0.1559, 0.1496, 0.1569, 0.1478, 0.1706]) assert np.allclose(original_image, expected_image, atol=1e-04) def test_depth(self): pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - base_path="stabilityai/stable-diffusion-2-1", + base_path="stabilityai/stable-diffusion-2-1-base", addon_path="UmerHA/Testing-ConrolNetXS-SD2.1-depth", ) pipe.enable_model_cpu_offload() @@ -382,7 +382,7 @@ def test_depth(self): assert image.shape == (512, 512, 3) original_image = image[-3:, -3:, -1].flatten() - expected_image = np.array([0.1504, 0.1448, 0.1742, 0.155, 0.1553, 0.1833, 0.1694, 0.1833, 0.2354]) + expected_image = np.array([0.1101, 0.1026, 0.1212, 0.114, 0.1169, 0.1266, 0.1191, 0.1266, 0.1712]) assert np.allclose(original_image, expected_image, atol=1e-04) @require_python39_or_higher diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index 8578148ac7a7..173b96c7f9e1 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -474,5 +474,5 @@ def test_depth(self): assert images[0].shape == (512, 512, 3) original_image = images[0, -3:, -3:, -1].flatten() - expected_image = np.array([0.4082, 0.3879, 0.2781, 0.2655, 0.327, 0.372, 0.3762, 0.3444, 0.3122]) + expected_image = np.array([0.4082, 0.3880, 0.2779, 0.2654, 0.327, 0.372, 0.3761, 0.3442, 0.3122]) assert np.allclose(original_image, expected_image, atol=1e-04) From 42920a976e2ee76f3eed56eecb68401d770dc2db Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Tue, 30 Jan 2024 16:55:48 +0100 Subject: [PATCH 37/75] Moved changes to UNetMidBlock2DCrossAttn --- src/diffusers/models/unet_2d_blocks.py | 92 -------------------- src/diffusers/models/unets/unet_2d_blocks.py | 27 ++++-- 2 files changed, 18 insertions(+), 101 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index cbc9d8714595..4628edfb7e0f 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -302,98 +302,6 @@ class SimpleCrossAttnDownBlock2D(SimpleCrossAttnDownBlock2D): deprecation_message = "Importing `SimpleCrossAttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SimpleCrossAttnDownBlock2D`, instead." deprecate("SimpleCrossAttnDownBlock2D", "0.29", deprecation_message) -class UNetMidBlock2DCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - out_channels: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_groups_out: Optional[int] = None, - resnet_pre_norm: bool = True, - num_attention_heads: int = 1, - output_scale_factor: float = 1.0, - cross_attention_dim: int = 1280, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", - ): - super().__init__() - - out_channels = out_channels or in_channels - self.in_channels = in_channels - self.out_channels = out_channels - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - resnet_groups_out = resnet_groups_out or resnet_groups - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - groups_out=resnet_groups_out, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for i in range(num_layers): - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups_out, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups_out, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) class KDownBlock2D(KDownBlock2D): deprecation_message = "Importing `KDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KDownBlock2D`, instead." diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index 3796896ef675..8b5943d94617 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -671,6 +671,7 @@ def __init__( self, in_channels: int, temb_channels: int, + out_channels: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: Union[int, Tuple[int]] = 1, @@ -678,6 +679,7 @@ def __init__( resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, + resnet_groups_out: Optional[int] = None, resnet_pre_norm: bool = True, num_attention_heads: int = 1, output_scale_factor: float = 1.0, @@ -689,6 +691,10 @@ def __init__( ): super().__init__() + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.has_cross_attention = True self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) @@ -697,14 +703,17 @@ def __init__( if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * num_layers + resnet_groups_out = resnet_groups_out or resnet_groups + # there is always at least one resnet resnets = [ ResnetBlock2D( in_channels=in_channels, - out_channels=in_channels, + out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, + groups_out=resnet_groups_out, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, @@ -719,11 +728,11 @@ def __init__( attentions.append( Transformer2DModel( num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, + out_channels // num_attention_heads, + in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, + norm_num_groups=resnet_groups_out, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, @@ -733,8 +742,8 @@ def __init__( attentions.append( DualTransformer2DModel( num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, + out_channels // num_attention_heads, + in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, @@ -742,11 +751,11 @@ def __init__( ) resnets.append( ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, + in_channels=out_channels, + out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, - groups=resnet_groups, + groups=resnet_groups_out, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, From b14b70cf580b523d1be8a79386e19c8b39e9d1ed Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Tue, 30 Jan 2024 17:08:01 +0100 Subject: [PATCH 38/75] tiny cleanups --- docs/source/en/_toctree.yml | 4 ++++ src/diffusers/models/attention.py | 1 - src/diffusers/models/unet_2d_blocks.py | 2 ++ src/diffusers/pipelines/controlnet/pipeline_controlnet.py | 4 ---- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 9647d92754dc..4cbe5a0fcc52 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -272,6 +272,10 @@ title: ControlNet - local: api/pipelines/controlnet_sdxl title: ControlNet with Stable Diffusion XL + - local: api/pipelines/controlnetxs + title: ControlNet-XS + - local: api/pipelines/controlnetxs_sdxl + title: ControlNet-XS with Stable Diffusion XL - local: api/pipelines/dance_diffusion title: Dance Diffusion - local: api/pipelines/ddim diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index fc4564c3a6ff..804c34d617d3 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -332,7 +332,6 @@ def forward( attention_mask=attention_mask, **cross_attention_kwargs, ) - if self.use_ada_layer_norm_zero: attn_output = gate_msa.unsqueeze(1) * attn_output elif self.use_ada_layer_norm_single: diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 4628edfb7e0f..497eabfc607b 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -298,6 +298,7 @@ class ResnetDownsampleBlock2D(ResnetDownsampleBlock2D): deprecation_message = "Importing `ResnetDownsampleBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D`, instead." deprecate("ResnetDownsampleBlock2D", "0.29", deprecation_message) + class SimpleCrossAttnDownBlock2D(SimpleCrossAttnDownBlock2D): deprecation_message = "Importing `SimpleCrossAttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SimpleCrossAttnDownBlock2D`, instead." deprecate("SimpleCrossAttnDownBlock2D", "0.29", deprecation_message) @@ -307,6 +308,7 @@ class KDownBlock2D(KDownBlock2D): deprecation_message = "Importing `KDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KDownBlock2D`, instead." deprecate("KDownBlock2D", "0.29", deprecation_message) + class KCrossAttnDownBlock2D(KCrossAttnDownBlock2D): deprecation_message = "Importing `KCrossAttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KCrossAttnDownBlock2D`, instead." deprecate("KCrossAttnDownBlock2D", "0.29", deprecation_message) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index db2d5416dfa2..ae9a28590151 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -1205,10 +1205,6 @@ def __call__( controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] - print( - f"Denoising step {i} > Right before controlnet application : Device type of controlnet >> ", - self.controlnet.device.type, - ) down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, From c163fcb9aa545b8cf2908e71bdeba1eb7dd8ab7a Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Tue, 30 Jan 2024 17:09:59 +0100 Subject: [PATCH 39/75] Removed stray prints --- src/diffusers/pipelines/controlnet/pipeline_controlnet.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index ae9a28590151..6cd1658c59a3 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -1214,10 +1214,6 @@ def __call__( guess_mode=guess_mode, return_dict=False, ) - print( - f"Denoising step {i} > Right after controlnet application : Device type of controlnet >> ", - self.controlnet.device.type, - ) if guess_mode and self.do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. @@ -1251,7 +1247,6 @@ def __call__( callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - print("btw, calling callback_on_step_end") callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) @@ -1262,7 +1257,6 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: - print("btw, calling callback") step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) From 145037dfed8a085bfb3b72b742869ce291587b3b Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 31 Jan 2024 10:01:34 +0100 Subject: [PATCH 40/75] Removed ip adapters + freeU - Removed ip adapters + freeU as they don't make sense for ControlNet-XS - Fixed imports of UNet components --- src/diffusers/models/controlnet_xs.py | 4 +- .../controlnet_xs/pipeline_controlnet_xs.py | 91 +++---------------- .../pipeline_controlnet_xs_sd_xl.py | 69 ++------------ .../versatile_diffusion/modeling_text_unet.py | 3 +- .../controlnet_xs/test_controlnetxs.py | 1 - .../controlnet_xs/test_controlnetxs_sdxl.py | 1 - 6 files changed, 29 insertions(+), 140 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 0f462d6e3ab1..b846faf6e7ad 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -15,8 +15,8 @@ Timesteps, ) from .modeling_utils import ModelMixin -from .unet_2d_blocks import Downsample2D, ResnetBlock2D, Transformer2DModel, UNetMidBlock2DCrossAttn, Upsample2D -from .unet_2d_condition import UNet2DConditionModel +from .unets.unet_2d_blocks import Downsample2D, ResnetBlock2D, Transformer2DModel, UNetMidBlock2DCrossAttn, Upsample2D +from .unets.unet_2d_condition import UNet2DConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index aee2a5123a82..9964fd767209 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -19,11 +19,11 @@ import PIL.Image import torch import torch.nn.functional as F -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetXSAddon, ControlNetXSModel, ImageProjection +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetXSAddon, ControlNetXSModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -88,7 +88,7 @@ class StableDiffusionControlNetXSPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance. @@ -100,7 +100,6 @@ class StableDiffusionControlNetXSPipeline( - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files Args: @@ -123,8 +122,8 @@ class StableDiffusionControlNetXSPipeline( A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - model_cpu_offload_seq = "text_encoder->image_encoder->controlnet->vae" - _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + model_cpu_offload_seq = "text_encoder->controlnet->vae" + _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -137,7 +136,6 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, - image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() @@ -176,7 +174,6 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, - image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) @@ -205,7 +202,15 @@ def from_pretrained(cls, base_path, addon_path, base_kwargs={}, addon_kwargs={}) controlnet_addon = ControlNetXSAddon.from_pretrained(addon_path, **addon_kwargs) unet = components["unet"] - components = {k: v for k, v in components.items() if k != "unet"} + + to_ignore = ["image_encoder"] + for item in to_ignore: + if item in components: + print( + f"Loaded base pipeline has component `{item}` which StableDiffusionControlNetXSPipeline can't use. It will be ignored." + ) + + components = {k: v for k, v in components.items() if k not in ["unet"] + to_ignore} controlnet = ControlNetXSModel(unet, controlnet_addon) return StableDiffusionControlNetXSPipeline(controlnet=controlnet, **components) @@ -484,31 +489,6 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -723,32 +703,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): - r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. - - The suffixes after the scaling factors represent the stages where they are being applied. - - Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values - that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. - - Args: - s1 (`float`): - Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - s2 (`float`): - Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. - b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. - """ - if not hasattr(self, "unet"): - raise ValueError("The pipeline must have `unet` for using FreeU.") - self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) - - def disable_freeu(self): - """Disables the FreeU mechanism if enabled.""" - self.unet.disable_freeu() - @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale def guidance_scale(self): @@ -791,7 +745,6 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -848,8 +801,6 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - ip_adapter_image (`PipelineImageInput`, *optional*): - Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -961,14 +912,6 @@ def __call__( if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - if ip_adapter_image is not None: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt, output_hidden_state - ) - if self.do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) - # 4. Prepare image if isinstance(controlnet, ControlNetXSModel): image = self.prepare_image( @@ -1005,9 +948,6 @@ def __call__( # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7.1 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None - # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) @@ -1034,7 +974,6 @@ def __call__( controlnet_cond=image, conditioning_scale=controlnet_conditioning_scale, cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, return_dict=True, do_control=do_control, ).sample diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index a204be74ed3e..a935b7dbbc78 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -24,7 +24,6 @@ CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, - CLIPVisionModelWithProjection, ) from diffusers.utils.import_utils import is_invisible_watermark_available @@ -148,14 +147,13 @@ class StableDiffusionXLControlNetXSPipeline( watermarker is used. """ - model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->controlnet->vae" + model_cpu_offload_seq = "text_encoder->text_encoder_2->controlnet->vae" _optional_components = [ "tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2", "feature_extractor", - "image_encoder", ] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -171,7 +169,6 @@ def __init__( force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, feature_extractor: CLIPImageProcessor = None, - image_encoder: CLIPVisionModelWithProjection = None, ): super().__init__() @@ -194,7 +191,6 @@ def __init__( controlnet=controlnet, scheduler=scheduler, feature_extractor=feature_extractor, - image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) @@ -230,7 +226,15 @@ def from_pretrained(cls, base_path, addon_path, base_kwargs={}, addon_kwargs={}) controlnet_addon = ControlNetXSAddon.from_pretrained(addon_path, **addon_kwargs) unet = components["unet"] - components = {k: v for k, v in components.items() if k != "unet"} + + to_ignore = ["image_encoder"] + for item in to_ignore: + if item in components: + print( + f"Loaded base pipeline has component `{item}` which StableDiffusionControlNetXSPipeline can't use. It will be ignored." + ) + + components = {k: v for k, v in components.items() if k not in ["unet"] + to_ignore} controlnet = ControlNetXSModel(unet, controlnet_addon) return StableDiffusionXLControlNetXSPipeline(controlnet=controlnet, **components) @@ -530,31 +534,6 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -808,34 +787,6 @@ def upcast_vae(self): self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): - r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. - - The suffixes after the scaling factors represent the stages where they are being applied. - - Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values - that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. - - Args: - s1 (`float`): - Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - s2 (`float`): - Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. - b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. - """ - if not hasattr(self, "unet"): - raise ValueError("The pipeline must have `unet` for using FreeU.") - # todo: check if works - self.controlnet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) - - def disable_freeu(self): - """Disables the FreeU mechanism if enabled.""" - # todo: check if works - self.controlnet.disable_freeu() - @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale def guidance_scale(self): diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index ab7bb4ab163f..ac37a8df5e3e 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -2260,12 +2260,13 @@ def __init__( self.has_cross_attention = True self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - resnet_groups_out = resnet_groups_out or resnet_groups # support for variable transformer layers per block if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * num_layers + resnet_groups_out = resnet_groups_out or resnet_groups + # there is always at least one resnet resnets = [ ResnetBlockFlat( diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index 49330dceb64e..45615b14dd4f 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -182,7 +182,6 @@ def get_dummy_components(self, time_cond_proj_dim=None): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, - "image_encoder": None, } return components diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index 173b96c7f9e1..1d403c1f2a4f 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -142,7 +142,6 @@ def get_dummy_components(self): "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, "feature_extractor": None, - "image_encoder": None, } return components From 0be0d6fa128be84f6835f801e7f8a9bcb9a41b1f Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 31 Jan 2024 14:42:49 +0100 Subject: [PATCH 41/75] Fixed test_save_load_float16 --- src/diffusers/models/controlnet_xs.py | 13 +++++ .../controlnet_xs/test_controlnetxs.py | 46 ++++++++++++++++++ .../controlnet_xs/test_controlnetxs_sdxl.py | 48 +++++++++++++++++++ 3 files changed, 107 insertions(+) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index b846faf6e7ad..1dc92d2a9739 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -1,3 +1,16 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import math from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index 45615b14dd4f..9131de9cb88e 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -327,6 +327,52 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() self.assertLess(max_diff, expected_max_difference) + @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + def test_save_load_float16(self, expected_max_diff=1e-2): + components = self.get_dummy_components() + for name, module in components.items(): + if hasattr(module, "half"): + components[name] = module.to(torch_device).half() + + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir_components: + with tempfile.TemporaryDirectory() as tmpdir_addon: + pipe.save_pretrained( + base_path=tmpdir_components, + addon_path=tmpdir_addon, + base_kwargs={"safe_serialization": False}, + addon_kwargs={"safe_serialization": False}, + ) + + pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, addon_path=tmpdir_addon) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for name, component in pipe_loaded.components.items(): + if hasattr(component, "dtype"): + self.assertTrue( + component.dtype == torch.float16, + f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.", + ) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs)[0] + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess( + max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." + ) @slow @require_torch_gpu diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index 1d403c1f2a4f..56c448115c4e 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -425,6 +425,54 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() self.assertLess(max_diff, expected_max_difference) + # copied from test_controlnetxs.py + @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + def test_save_load_float16(self, expected_max_diff=1e-2): + components = self.get_dummy_components() + for name, module in components.items(): + if hasattr(module, "half"): + components[name] = module.to(torch_device).half() + + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir_components: + with tempfile.TemporaryDirectory() as tmpdir_addon: + pipe.save_pretrained( + base_path=tmpdir_components, + addon_path=tmpdir_addon, + base_kwargs={"safe_serialization": False}, + addon_kwargs={"safe_serialization": False}, + ) + + pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, addon_path=tmpdir_addon) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for name, component in pipe_loaded.components.items(): + if hasattr(component, "dtype"): + self.assertTrue( + component.dtype == torch.float16, + f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.", + ) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs)[0] + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess( + max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." + ) + @slow @require_torch_gpu From a21479ea628b461311e61143213227e9f1481d64 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 1 Feb 2024 15:22:55 +0100 Subject: [PATCH 42/75] Make style, quality, fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 3 +++ tests/pipelines/controlnet_xs/test_controlnetxs.py | 1 + 2 files changed, 4 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 1bf3806c8776..c7fa78445982 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -108,6 +108,8 @@ def from_pretrained(cls, *args, **kwargs): class ControlNetXSModel(metaclass=DummyObject): + _backends = ["torch"] + def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @@ -119,6 +121,7 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) + class I2VGenXLUNet(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index 9131de9cb88e..0684ec770370 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -374,6 +374,7 @@ def test_save_load_float16(self, expected_max_diff=1e-2): max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." ) + @slow @require_torch_gpu class ControlNetXSPipelineSlowTests(unittest.TestCase): From e2d009cd27734ca9d0841f8575d701c93a587efa Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 29 Feb 2024 19:53:09 +0100 Subject: [PATCH 43/75] Changed loading/saving API for ControlNetXS - Changed loading/saving API for ControlNetXS - other small fixes --- src/diffusers/models/controlnet_xs.py | 10 +++- .../controlnet_xs/pipeline_controlnet_xs.py | 55 ++++++++---------- .../pipeline_controlnet_xs_sd_xl.py | 56 ++++++++----------- .../controlnet_xs/test_controlnetxs.py | 22 +++----- .../controlnet_xs/test_controlnetxs_sdxl.py | 22 +++----- 5 files changed, 67 insertions(+), 98 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 1dc92d2a9739..5176044187da 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -266,7 +266,11 @@ def __init__( time_embedding_input_dim: Optional[int] = 320, time_embedding_dim: Optional[int] = 1280, learn_time_embedding: bool = False, - channels_base: Dict[str, List[Tuple[int]]] = gather_base_subblock_sizes((320, 640, 1280, 1280)), + channels_base: Dict[str, List[Tuple[int]]] = { + "down - out": [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280], + "mid - out": 1280, + "up - in": [1280, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320], + }, attention_head_dim: Union[int, Tuple[int]] = 4, block_out_channels: Tuple[int] = (4, 8, 16, 16), cross_attention_dim: int = 1024, @@ -462,8 +466,8 @@ class ControlNetXSModel(nn.Module): ctrl_addon (`ControlNetXSAddon`): The control addon. time_embedding_mix (`float`, defaults to 1.0): - If 0, then only the base model's time embedding is be used. - If 1, then only the control model's time embedding is be used. + If 0, then only the base model's time embedding is used. + If 1, then only the control model's time embedding is used. Otherwise, both are combined. """ diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 9964fd767209..f3f560ac30e4 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -23,7 +23,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetXSAddon, ControlNetXSModel +from ...models import AutoencoderKL, ControlNetXSModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -67,10 +67,12 @@ >>> # initialize the models and pipeline >>> controlnet_conditioning_scale = 0.5 + >>> controlnet_xs_addon = ControlNetXSAddon.from_pretrained( + ... "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 + ... ) >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - >>> base_path="stabilityai/stable-diffusion-2-1-base", base_kwargs=dict(torch_dtype=torch.float16), - >>> addon_path="UmerHA/Testing-ConrolNetXS-SD2.1-canny", addon_kwargs=dict(torch_dtype=torch.float16), - >>> ) + ... "stabilityai/stable-diffusion-2-1-base", controlnet_xs_addon=controlnet_xs_addon, torch_dtype=torch.float16 + ... ) >>> pipe.enable_model_cpu_offload() >>> # get canny image @@ -183,23 +185,20 @@ def __init__( self.register_to_config(requires_safety_checker=requires_safety_checker) @classmethod - def from_pretrained(cls, base_path, addon_path, base_kwargs={}, addon_kwargs={}): + def from_pretrained(cls, base_path, controlnet_addon, **kwargs): """ Instantiates pipeline from a `StableDiffusionPipeline` and a `ControlNetXSAddon`. Arguments: base_path (`str` or `os.PathLike`): Directory to load underlying `StableDiffusionPipeline` from. - addon_path (`str` or `os.PathLike`): - Directory to load underlying `ControlNetXSAddon` model from. - base_kwargs (`Dict[str, Any]`, *optional*): + controlnet_addon (`ControlNetXSAddon`): + A `ControlNetXSAddon` model. + kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments passed along to the [`~StableDiffusionPipeline.from_pretrained`] method. - addon_kwargs (`Dict[str, Any]`, *optional*): - Additional keyword arguments passed along to the [`~ControlNetXSAddon.from_pretrained`] method. """ - components = StableDiffusionPipeline.from_pretrained(base_path, **base_kwargs).components - controlnet_addon = ControlNetXSAddon.from_pretrained(addon_path, **addon_kwargs) + components = StableDiffusionPipeline.from_pretrained(base_path, **kwargs).components unet = components["unet"] @@ -215,31 +214,21 @@ def from_pretrained(cls, base_path, addon_path, base_kwargs={}, addon_kwargs={}) controlnet = ControlNetXSModel(unet, controlnet_addon) return StableDiffusionControlNetXSPipeline(controlnet=controlnet, **components) - def save_pretrained(self, base_path, addon_path, base_kwargs={}, addon_kwargs={}): - """ - - Separately save the underlying `StableDiffusionPipeline` and the `ControlNetXSAddon` so the pipeline is easily reloaded using the - [`~StableDiffusionControlNetXSPipeline.from_pretrained`] class method. - - Arguments: - base_path (`str` or `os.PathLike`): - Directory to save underlying `StableDiffusionPipeline` to. Will be created if it doesn't exist. - addon_path (`str` or `os.PathLike`): - Directory to save underlying `ControlNetXSAddon` model to. Will be created if it doesn't exist. - base_kwargs (`Dict[str, Any]`, *optional*): - Additional keyword arguments passed along to the [`~StableDiffusionPipeline.save_pretrained`] method. - addon_kwargs (`Dict[str, Any]`, *optional*): - Additional keyword arguments passed along to the [`~ControlNetXSAddon.save_pretrained`] method. - - """ + def save_pretrained(self, *args, **kwargs): + raise EnvironmentError( + "Save the underlying `StableDiffusionPipeline` and the `ControlNetXSAddon` separately" + " by using `pipe.get_base_pipeline().save_pretrained()` and `pipe.get_controlnet_addon().save_pretrained()`." + ) + def get_base_pipeline(self): + """Get underlying `StableDiffusionPipeline` without the `ControlNetXSAddon` model.""" components = {k: v for k, v in self.components.items() if k != "controlnet"} components["unet"] = self.components["controlnet"].base_model + return StableDiffusionPipeline(**components) - controlnet_addon = self.components["controlnet"].ctrl_addon - - StableDiffusionPipeline(**components).save_pretrained(base_path, **base_kwargs) - controlnet_addon.save_pretrained(addon_path, **addon_kwargs) + def get_controlnet_addon(self): + """Get the `ControlNetXSAddon` model.""" + return self.components["controlnet"].ctrl_addon # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index a935b7dbbc78..a91e68611c7b 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -30,7 +30,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetXSAddon, ControlNetXSModel +from ...models import AutoencoderKL, ControlNetXSModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -83,11 +83,12 @@ >>> # initialize the models and pipeline >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> controlnet_xs_addon = ControlNetXSAddon.from_pretrained( + ... "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16 + ... ) >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - >>> base_path="stabilityai/stable-diffusion-xl-base-1.0", base_kwargs=dict(vae=vae, torch_dtype=torch.float16), - >>> addon_path="UmerHA/Testing-ConrolNetXS-SDXL-canny", addon_kwargs=dict(torch_dtype=torch.float16), - >>> ) - + ... base_path="stabilityai/stable-diffusion-xl-base-1.0", controlnet_xs_addon=controlnet_xs_addon, torch_dtype=torch.float16 + ... ) >>> pipe.enable_model_cpu_offload() >>> # get canny image @@ -207,23 +208,20 @@ def __init__( self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) @classmethod - def from_pretrained(cls, base_path, addon_path, base_kwargs={}, addon_kwargs={}): + def from_pretrained(cls, base_path, controlnet_addon, **kwargs): """ Instantiates pipeline from a `StableDiffusionXLPipeline` and a `ControlNetXSAddon`. Arguments: base_path (`str` or `os.PathLike`): Directory to load underlying `StableDiffusionXLPipeline` from. - addon_path (`str` or `os.PathLike`): - Directory to load underlying `ControlNetXSAddon` model from. - base_kwargs (`Dict[str, Any]`, *optional*): + controlnet_addon (`ControlNetXSAddon`): + A `ControlNetXSAddon` model. + kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments passed along to the [`~StableDiffusionXLPipeline.from_pretrained`] method. - addon_kwargs (`Dict[str, Any]`, *optional*): - Additional keyword arguments passed along to the [`~ControlNetXSAddon.from_pretrained`] method. """ - components = StableDiffusionXLPipeline.from_pretrained(base_path, **base_kwargs).components - controlnet_addon = ControlNetXSAddon.from_pretrained(addon_path, **addon_kwargs) + components = StableDiffusionXLPipeline.from_pretrained(base_path, **kwargs).components unet = components["unet"] @@ -239,31 +237,21 @@ def from_pretrained(cls, base_path, addon_path, base_kwargs={}, addon_kwargs={}) controlnet = ControlNetXSModel(unet, controlnet_addon) return StableDiffusionXLControlNetXSPipeline(controlnet=controlnet, **components) - def save_pretrained(self, base_path, addon_path, base_kwargs={}, addon_kwargs={}): - """ - - Separately save the underlying `StableDiffusionXLPipeline` and the `ControlNetXSAddon` so the pipeline is easily reloaded using the - [`~StableDiffusionControlNetXSPipeline.from_pretrained`] class method. - - Arguments: - base_path (`str` or `os.PathLike`): - Directory to save underlying `StableDiffusionXLPipeline` to. Will be created if it doesn't exist. - addon_path (`str` or `os.PathLike`): - Directory to save underlying `ControlNetXSAddon` model to. Will be created if it doesn't exist. - base_kwargs (`Dict[str, Any]`, *optional*): - Additional keyword arguments passed along to the [`~StableDiffusionXLPipeline.save_pretrained`] method. - addon_kwargs (`Dict[str, Any]`, *optional*): - Additional keyword arguments passed along to the [`~ControlNetXSAddon.save_pretrained`] method. - - """ + def save_pretrained(self, *args, **kwargs): + raise EnvironmentError( + "Save the underlying `StableDiffusionXLPipeline` and the `ControlNetXSAddon` separately" + " by using `pipe.get_base_pipeline().save_pretrained()` and `pipe.get_controlnet_addon().save_pretrained()`." + ) + def get_base_pipeline(self): + """Get underlying `StableDiffusionXLPipeline` without the `ControlNetXSAddon` model.""" components = {k: v for k, v in self.components.items() if k != "controlnet"} components["unet"] = self.components["controlnet"].base_model + return StableDiffusionXLPipeline(**components) - controlnet_addon = self.components["controlnet"].ctrl_addon - - StableDiffusionXLPipeline(**components).save_pretrained(base_path, **base_kwargs) - controlnet_addon.save_pretrained(addon_path, **addon_kwargs) + def get_controlnet_addon(self): + """Get the `ControlNetXSAddon` model.""" + return self.components["controlnet"].ctrl_addon # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index 0684ec770370..f8fba1f96c8e 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -259,14 +259,11 @@ def test_save_load_local(self, expected_max_difference=5e-4): with tempfile.TemporaryDirectory() as tmpdir_components: with tempfile.TemporaryDirectory() as tmpdir_addon: - pipe.save_pretrained( - base_path=tmpdir_components, - addon_path=tmpdir_addon, - base_kwargs={"safe_serialization": False}, - addon_kwargs={"safe_serialization": False}, - ) + pipe.get_base_pipeline().save_pretrained(tmpdir_components, safe_serialization=False) + pipe.get_controlnet_addon().save_pretrained(tmpdir_addon, safe_serialization=False) - pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, addon_path=tmpdir_addon) + addon_loaded = ControlNetXSAddon.from_pretrained(tmpdir_addon) + pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, controlnet_addon=addon_loaded) for component in pipe_loaded.components.values(): if hasattr(component, "set_default_attn_processor"): @@ -300,14 +297,11 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): with tempfile.TemporaryDirectory() as tmpdir_components: with tempfile.TemporaryDirectory() as tmpdir_addon: - pipe.save_pretrained( - base_path=tmpdir_components, - addon_path=tmpdir_addon, - base_kwargs={"safe_serialization": False}, - addon_kwargs={"safe_serialization": False}, - ) + pipe.get_base_pipeline().save_pretrained(tmpdir_components, safe_serialization=False) + pipe.get_controlnet_addon().save_pretrained(tmpdir_addon, safe_serialization=False) - pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, addon_path=tmpdir_addon) + addon_loaded = ControlNetXSAddon.from_pretrained(tmpdir_addon) + pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, controlnet_addon=addon_loaded) for component in pipe_loaded.components.values(): if hasattr(component, "set_default_attn_processor"): diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index 56c448115c4e..e854ede94259 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -328,14 +328,11 @@ def test_save_load_local(self, expected_max_difference=5e-4): with tempfile.TemporaryDirectory() as tmpdir_components: with tempfile.TemporaryDirectory() as tmpdir_addon: - pipe.save_pretrained( - base_path=tmpdir_components, - addon_path=tmpdir_addon, - base_kwargs={"safe_serialization": False}, - addon_kwargs={"safe_serialization": False}, - ) + pipe.get_base_pipeline().save_pretrained(tmpdir_components, safe_serialization=False) + pipe.get_controlnet_addon().save_pretrained(tmpdir_addon, safe_serialization=False) - pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, addon_path=tmpdir_addon) + addon_loaded = ControlNetXSAddon.from_pretrained(tmpdir_addon) + pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, controlnet_addon=addon_loaded) for component in pipe_loaded.components.values(): if hasattr(component, "set_default_attn_processor"): @@ -390,14 +387,11 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): with tempfile.TemporaryDirectory() as tmpdir_components: with tempfile.TemporaryDirectory() as tmpdir_addon: - pipe.save_pretrained( - base_path=tmpdir_components, - addon_path=tmpdir_addon, - base_kwargs={"safe_serialization": False}, - addon_kwargs={"safe_serialization": False}, - ) + pipe.get_base_pipeline().save_pretrained(tmpdir_components, safe_serialization=False) + pipe.get_controlnet_addon().save_pretrained(tmpdir_addon, safe_serialization=False) - pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, addon_path=tmpdir_addon) + addon_loaded = ControlNetXSAddon.from_pretrained(tmpdir_addon) + pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, controlnet_addon=addon_loaded) for component in pipe_loaded.components.values(): if hasattr(component, "set_default_attn_processor"): From 2bca2caf02dcb776104f81e32c9b98051bc6b09d Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 29 Feb 2024 19:53:44 +0100 Subject: [PATCH 44/75] Removed ControlNet-XS from research examples --- .../research_projects/controlnetxs/README.md | 16 - .../controlnetxs/README_sdxl.md | 15 - .../controlnetxs/controlnetxs.py | 1014 ---------------- .../controlnetxs/infer_sd_controlnetxs.py | 58 - .../controlnetxs/infer_sdxl_controlnetxs.py | 57 - .../controlnetxs/pipeline_controlnet_xs.py | 901 -------------- .../pipeline_controlnet_xs_sd_xl.py | 1073 ----------------- 7 files changed, 3134 deletions(-) delete mode 100644 examples/research_projects/controlnetxs/README.md delete mode 100644 examples/research_projects/controlnetxs/README_sdxl.md delete mode 100644 examples/research_projects/controlnetxs/controlnetxs.py delete mode 100644 examples/research_projects/controlnetxs/infer_sd_controlnetxs.py delete mode 100644 examples/research_projects/controlnetxs/infer_sdxl_controlnetxs.py delete mode 100644 examples/research_projects/controlnetxs/pipeline_controlnet_xs.py delete mode 100644 examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py diff --git a/examples/research_projects/controlnetxs/README.md b/examples/research_projects/controlnetxs/README.md deleted file mode 100644 index 72ed91c01db2..000000000000 --- a/examples/research_projects/controlnetxs/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# ControlNet-XS - -ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results. - -Like the original ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process. - -ControlNet-XS generates images with comparable quality to a regular ControlNet, but it is 20-25% faster ([see benchmark](https://github.com/UmerHA/controlnet-xs-benchmark/blob/main/Speed%20Benchmark.ipynb) with StableDiffusion-XL) and uses ~45% less memory. - -Here's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/): - -*With increasing computing capabilities, current model architectures appear to follow the trend of simply upscaling all components without validating the necessity for doing so. In this project we investigate the size and architectural design of ControlNet [Zhang et al., 2023] for controlling the image generation process with stable diffusion-based models. We show that a new architecture with as little as 1% of the parameters of the base model achieves state-of-the art results, considerably better than ControlNet in terms of FID score. Hence we call it ControlNet-XS. We provide the code for controlling StableDiffusion-XL [Podell et al., 2023] (Model B, 48M Parameters) and StableDiffusion 2.1 [Rombach et al. 2022] (Model B, 14M Parameters), all under openrail license.* - -This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️ - - -> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. \ No newline at end of file diff --git a/examples/research_projects/controlnetxs/README_sdxl.md b/examples/research_projects/controlnetxs/README_sdxl.md deleted file mode 100644 index d401c1e76698..000000000000 --- a/examples/research_projects/controlnetxs/README_sdxl.md +++ /dev/null @@ -1,15 +0,0 @@ -# ControlNet-XS with Stable Diffusion XL - -ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results. - -Like the original ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process. - -ControlNet-XS generates images with comparable quality to a regular ControlNet, but it is 20-25% faster ([see benchmark](https://github.com/UmerHA/controlnet-xs-benchmark/blob/main/Speed%20Benchmark.ipynb)) and uses ~45% less memory. - -Here's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/): - -*With increasing computing capabilities, current model architectures appear to follow the trend of simply upscaling all components without validating the necessity for doing so. In this project we investigate the size and architectural design of ControlNet [Zhang et al., 2023] for controlling the image generation process with stable diffusion-based models. We show that a new architecture with as little as 1% of the parameters of the base model achieves state-of-the art results, considerably better than ControlNet in terms of FID score. Hence we call it ControlNet-XS. We provide the code for controlling StableDiffusion-XL [Podell et al., 2023] (Model B, 48M Parameters) and StableDiffusion 2.1 [Rombach et al. 2022] (Model B, 14M Parameters), all under openrail license.* - -This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️ - -> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. \ No newline at end of file diff --git a/examples/research_projects/controlnetxs/controlnetxs.py b/examples/research_projects/controlnetxs/controlnetxs.py deleted file mode 100644 index 14ad1d8a3af9..000000000000 --- a/examples/research_projects/controlnetxs/controlnetxs.py +++ /dev/null @@ -1,1014 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import functional as F -from torch.nn.modules.normalization import GroupNorm - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.models.attention_processor import USE_PEFT_BACKEND, AttentionProcessor -from diffusers.models.autoencoders import AutoencoderKL -from diffusers.models.lora import LoRACompatibleConv -from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.unets.unet_2d_blocks import ( - CrossAttnDownBlock2D, - CrossAttnUpBlock2D, - DownBlock2D, - Downsample2D, - ResnetBlock2D, - Transformer2DModel, - UpBlock2D, - Upsample2D, -) -from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel -from diffusers.utils import BaseOutput, logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class ControlNetXSOutput(BaseOutput): - """ - The output of [`ControlNetXSModel`]. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - The output of the `ControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model - output, but is already the final output. - """ - - sample: torch.FloatTensor = None - - -# copied from diffusers.models.controlnet.ControlNetConditioningEmbedding -class ControlNetConditioningEmbedding(nn.Module): - """ - Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN - [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized - training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the - convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides - (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full - model) to encode image-space conditions ... into feature maps ..." - """ - - def __init__( - self, - conditioning_embedding_channels: int, - conditioning_channels: int = 3, - block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), - ): - super().__init__() - - self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) - - self.blocks = nn.ModuleList([]) - - for i in range(len(block_out_channels) - 1): - channel_in = block_out_channels[i] - channel_out = block_out_channels[i + 1] - self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) - self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) - - self.conv_out = zero_module( - nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) - ) - - def forward(self, conditioning): - embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) - - for block in self.blocks: - embedding = block(embedding) - embedding = F.silu(embedding) - - embedding = self.conv_out(embedding) - - return embedding - - -class ControlNetXSModel(ModelMixin, ConfigMixin): - r""" - A ControlNet-XS model - - This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic - methods implemented for all models (such as downloading or saving). - - Most of parameters for this model are passed into the [`UNet2DConditionModel`] it creates. Check the documentation - of [`UNet2DConditionModel`] for them. - - Parameters: - conditioning_channels (`int`, defaults to 3): - Number of channels of conditioning input (e.g. an image) - controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): - The channel order of conditional image. Will convert to `rgb` if it's `bgr`. - conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): - The tuple of output channel for each block in the `controlnet_cond_embedding` layer. - time_embedding_input_dim (`int`, defaults to 320): - Dimension of input into time embedding. Needs to be same as in the base model. - time_embedding_dim (`int`, defaults to 1280): - Dimension of output from time embedding. Needs to be same as in the base model. - learn_embedding (`bool`, defaults to `False`): - Whether to use time embedding of the control model. If yes, the time embedding is a linear interpolation of - the time embeddings of the control and base model with interpolation parameter `time_embedding_mix**3`. - time_embedding_mix (`float`, defaults to 1.0): - Linear interpolation parameter used if `learn_embedding` is `True`. A value of 1.0 means only the - control model's time embedding will be used. A value of 0.0 means only the base model's time embedding will be used. - base_model_channel_sizes (`Dict[str, List[Tuple[int]]]`): - Channel sizes of each subblock of base model. Use `gather_subblock_sizes` on your base model to compute it. - """ - - @classmethod - def init_original(cls, base_model: UNet2DConditionModel, is_sdxl=True): - """ - Create a ControlNetXS model with the same parameters as in the original paper (https://github.com/vislearn/ControlNet-XS). - - Parameters: - base_model (`UNet2DConditionModel`): - Base UNet model. Needs to be either StableDiffusion or StableDiffusion-XL. - is_sdxl (`bool`, defaults to `True`): - Whether passed `base_model` is a StableDiffusion-XL model. - """ - - def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_attn_heads: int): - """ - Currently, diffusers can only set the dimension of attention heads (see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why). - The original ControlNet-XS model, however, define the number of attention heads. - That's why compute the dimensions needed to get the correct number of attention heads. - """ - block_out_channels = [int(size_ratio * c) for c in base_model.config.block_out_channels] - dim_attn_heads = [math.ceil(c / num_attn_heads) for c in block_out_channels] - return dim_attn_heads - - if is_sdxl: - return ControlNetXSModel.from_unet( - base_model, - time_embedding_mix=0.95, - learn_embedding=True, - size_ratio=0.1, - conditioning_embedding_out_channels=(16, 32, 96, 256), - num_attention_heads=get_dim_attn_heads(base_model, 0.1, 64), - ) - else: - return ControlNetXSModel.from_unet( - base_model, - time_embedding_mix=1.0, - learn_embedding=True, - size_ratio=0.0125, - conditioning_embedding_out_channels=(16, 32, 96, 256), - num_attention_heads=get_dim_attn_heads(base_model, 0.0125, 8), - ) - - @classmethod - def _gather_subblock_sizes(cls, unet: UNet2DConditionModel, base_or_control: str): - """To create correctly sized connections between base and control model, we need to know - the input and output channels of each subblock. - - Parameters: - unet (`UNet2DConditionModel`): - Unet of which the subblock channels sizes are to be gathered. - base_or_control (`str`): - Needs to be either "base" or "control". If "base", decoder is also considered. - """ - if base_or_control not in ["base", "control"]: - raise ValueError("`base_or_control` needs to be either `base` or `control`") - - channel_sizes = {"down": [], "mid": [], "up": []} - - # input convolution - channel_sizes["down"].append((unet.conv_in.in_channels, unet.conv_in.out_channels)) - - # encoder blocks - for module in unet.down_blocks: - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): - for r in module.resnets: - channel_sizes["down"].append((r.in_channels, r.out_channels)) - if module.downsamplers: - channel_sizes["down"].append( - (module.downsamplers[0].channels, module.downsamplers[0].out_channels) - ) - else: - raise ValueError(f"Encountered unknown module of type {type(module)} while creating ControlNet-XS.") - - # middle block - channel_sizes["mid"].append((unet.mid_block.resnets[0].in_channels, unet.mid_block.resnets[0].out_channels)) - - # decoder blocks - if base_or_control == "base": - for module in unet.up_blocks: - if isinstance(module, (CrossAttnUpBlock2D, UpBlock2D)): - for r in module.resnets: - channel_sizes["up"].append((r.in_channels, r.out_channels)) - else: - raise ValueError( - f"Encountered unknown module of type {type(module)} while creating ControlNet-XS." - ) - - return channel_sizes - - @register_to_config - def __init__( - self, - conditioning_channels: int = 3, - conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), - controlnet_conditioning_channel_order: str = "rgb", - time_embedding_input_dim: int = 320, - time_embedding_dim: int = 1280, - time_embedding_mix: float = 1.0, - learn_embedding: bool = False, - base_model_channel_sizes: Dict[str, List[Tuple[int]]] = { - "down": [ - (4, 320), - (320, 320), - (320, 320), - (320, 320), - (320, 640), - (640, 640), - (640, 640), - (640, 1280), - (1280, 1280), - ], - "mid": [(1280, 1280)], - "up": [ - (2560, 1280), - (2560, 1280), - (1920, 1280), - (1920, 640), - (1280, 640), - (960, 640), - (960, 320), - (640, 320), - (640, 320), - ], - }, - sample_size: Optional[int] = None, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - norm_num_groups: Optional[int] = 32, - cross_attention_dim: Union[int, Tuple[int]] = 1280, - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, - num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, - upcast_attention: bool = False, - ): - super().__init__() - - # 1 - Create control unet - self.control_model = UNet2DConditionModel( - sample_size=sample_size, - down_block_types=down_block_types, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - norm_num_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - transformer_layers_per_block=transformer_layers_per_block, - attention_head_dim=num_attention_heads, - use_linear_projection=True, - upcast_attention=upcast_attention, - time_embedding_dim=time_embedding_dim, - ) - - # 2 - Do model surgery on control model - # 2.1 - Allow to use the same time information as the base model - adjust_time_dims(self.control_model, time_embedding_input_dim, time_embedding_dim) - - # 2.2 - Allow for information infusion from base model - - # We concat the output of each base encoder subblocks to the input of the next control encoder subblock - # (We ignore the 1st element, as it represents the `conv_in`.) - extra_input_channels = [input_channels for input_channels, _ in base_model_channel_sizes["down"][1:]] - it_extra_input_channels = iter(extra_input_channels) - - for b, block in enumerate(self.control_model.down_blocks): - for r in range(len(block.resnets)): - increase_block_input_in_encoder_resnet( - self.control_model, block_no=b, resnet_idx=r, by=next(it_extra_input_channels) - ) - - if block.downsamplers: - increase_block_input_in_encoder_downsampler( - self.control_model, block_no=b, by=next(it_extra_input_channels) - ) - - increase_block_input_in_mid_resnet(self.control_model, by=extra_input_channels[-1]) - - # 2.3 - Make group norms work with modified channel sizes - adjust_group_norms(self.control_model) - - # 3 - Gather Channel Sizes - self.ch_inout_ctrl = ControlNetXSModel._gather_subblock_sizes(self.control_model, base_or_control="control") - self.ch_inout_base = base_model_channel_sizes - - # 4 - Build connections between base and control model - self.down_zero_convs_out = nn.ModuleList([]) - self.down_zero_convs_in = nn.ModuleList([]) - self.middle_block_out = nn.ModuleList([]) - self.middle_block_in = nn.ModuleList([]) - self.up_zero_convs_out = nn.ModuleList([]) - self.up_zero_convs_in = nn.ModuleList([]) - - for ch_io_base in self.ch_inout_base["down"]: - self.down_zero_convs_in.append(self._make_zero_conv(in_channels=ch_io_base[1], out_channels=ch_io_base[1])) - for i in range(len(self.ch_inout_ctrl["down"])): - self.down_zero_convs_out.append( - self._make_zero_conv(self.ch_inout_ctrl["down"][i][1], self.ch_inout_base["down"][i][1]) - ) - - self.middle_block_out = self._make_zero_conv( - self.ch_inout_ctrl["mid"][-1][1], self.ch_inout_base["mid"][-1][1] - ) - - self.up_zero_convs_out.append( - self._make_zero_conv(self.ch_inout_ctrl["down"][-1][1], self.ch_inout_base["mid"][-1][1]) - ) - for i in range(1, len(self.ch_inout_ctrl["down"])): - self.up_zero_convs_out.append( - self._make_zero_conv(self.ch_inout_ctrl["down"][-(i + 1)][1], self.ch_inout_base["up"][i - 1][1]) - ) - - # 5 - Create conditioning hint embedding - self.controlnet_cond_embedding = ControlNetConditioningEmbedding( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=conditioning_embedding_out_channels, - conditioning_channels=conditioning_channels, - ) - - # In the mininal implementation setting, we only need the control model up to the mid block - del self.control_model.up_blocks - del self.control_model.conv_norm_out - del self.control_model.conv_out - - @classmethod - def from_unet( - cls, - unet: UNet2DConditionModel, - conditioning_channels: int = 3, - conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), - controlnet_conditioning_channel_order: str = "rgb", - learn_embedding: bool = False, - time_embedding_mix: float = 1.0, - block_out_channels: Optional[Tuple[int]] = None, - size_ratio: Optional[float] = None, - num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, - norm_num_groups: Optional[int] = None, - ): - r""" - Instantiate a [`ControlNetXSModel`] from [`UNet2DConditionModel`]. - - Parameters: - unet (`UNet2DConditionModel`): - The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it. - conditioning_channels (`int`, defaults to 3): - Number of channels of conditioning input (e.g. an image) - conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): - The tuple of output channel for each block in the `controlnet_cond_embedding` layer. - controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): - The channel order of conditional image. Will convert to `rgb` if it's `bgr`. - learn_embedding (`bool`, defaults to `False`): - Wether to use time embedding of the control model. If yes, the time embedding is a linear interpolation - of the time embeddings of the control and base model with interpolation parameter - `time_embedding_mix**3`. - time_embedding_mix (`float`, defaults to 1.0): - Linear interpolation parameter used if `learn_embedding` is `True`. - block_out_channels (`Tuple[int]`, *optional*): - Down blocks output channels in control model. Either this or `size_ratio` must be given. - size_ratio (float, *optional*): - When given, block_out_channels is set to a relative fraction of the base model's block_out_channels. - Either this or `block_out_channels` must be given. - num_attention_heads (`Union[int, Tuple[int]]`, *optional*): - The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. - norm_num_groups (int, *optional*, defaults to `None`): - The number of groups to use for the normalization of the control unet. If `None`, - `int(unet.config.norm_num_groups * size_ratio)` is taken. - """ - - # Check input - fixed_size = block_out_channels is not None - relative_size = size_ratio is not None - if not (fixed_size ^ relative_size): - raise ValueError( - "Pass exactly one of `block_out_channels` (for absolute sizing) or `control_model_ratio` (for relative sizing)." - ) - - # Create model - if block_out_channels is None: - block_out_channels = [int(size_ratio * c) for c in unet.config.block_out_channels] - - # Check that attention heads and group norms match channel sizes - # - attention heads - def attn_heads_match_channel_sizes(attn_heads, channel_sizes): - if isinstance(attn_heads, (tuple, list)): - return all(c % a == 0 for a, c in zip(attn_heads, channel_sizes)) - else: - return all(c % attn_heads == 0 for c in channel_sizes) - - num_attention_heads = num_attention_heads or unet.config.attention_head_dim - if not attn_heads_match_channel_sizes(num_attention_heads, block_out_channels): - raise ValueError( - f"The dimension of attention heads ({num_attention_heads}) must divide `block_out_channels` ({block_out_channels}). If you didn't set `num_attention_heads` the default settings don't match your model. Set `num_attention_heads` manually." - ) - - # - group norms - def group_norms_match_channel_sizes(num_groups, channel_sizes): - return all(c % num_groups == 0 for c in channel_sizes) - - if norm_num_groups is None: - if group_norms_match_channel_sizes(unet.config.norm_num_groups, block_out_channels): - norm_num_groups = unet.config.norm_num_groups - else: - norm_num_groups = min(block_out_channels) - - if group_norms_match_channel_sizes(norm_num_groups, block_out_channels): - print( - f"`norm_num_groups` was set to `min(block_out_channels)` (={norm_num_groups}) so it divides all block_out_channels` ({block_out_channels}). Set it explicitly to remove this information." - ) - else: - raise ValueError( - f"`block_out_channels` ({block_out_channels}) don't match the base models `norm_num_groups` ({unet.config.norm_num_groups}). Setting `norm_num_groups` to `min(block_out_channels)` ({norm_num_groups}) didn't fix this. Pass `norm_num_groups` explicitly so it divides all block_out_channels." - ) - - def get_time_emb_input_dim(unet: UNet2DConditionModel): - return unet.time_embedding.linear_1.in_features - - def get_time_emb_dim(unet: UNet2DConditionModel): - return unet.time_embedding.linear_2.out_features - - # Clone params from base unet if - # (i) it's required to build SD or SDXL, and - # (ii) it's not used for the time embedding (as time embedding of control model is never used), and - # (iii) it's not set further below anyway - to_keep = [ - "cross_attention_dim", - "down_block_types", - "sample_size", - "transformer_layers_per_block", - "up_block_types", - "upcast_attention", - ] - kwargs = {k: v for k, v in dict(unet.config).items() if k in to_keep} - kwargs.update(block_out_channels=block_out_channels) - kwargs.update(num_attention_heads=num_attention_heads) - kwargs.update(norm_num_groups=norm_num_groups) - - # Add controlnetxs-specific params - kwargs.update( - conditioning_channels=conditioning_channels, - controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, - time_embedding_input_dim=get_time_emb_input_dim(unet), - time_embedding_dim=get_time_emb_dim(unet), - time_embedding_mix=time_embedding_mix, - learn_embedding=learn_embedding, - base_model_channel_sizes=ControlNetXSModel._gather_subblock_sizes(unet, base_or_control="base"), - conditioning_embedding_out_channels=conditioning_embedding_out_channels, - ) - - return cls(**kwargs) - - @property - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - return self.control_model.attn_processors - - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - self.control_model.set_attn_processor(processor) - - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - self.control_model.set_default_attn_processor() - - def set_attention_slice(self, slice_size): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module splits the input tensor in slices to compute attention in - several steps. This is useful for saving some memory in exchange for a small decrease in speed. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If - `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - self.control_model.set_attention_slice(slice_size) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (UNet2DConditionModel)): - if value: - module.enable_gradient_checkpointing() - else: - module.disable_gradient_checkpointing() - - def forward( - self, - base_model: UNet2DConditionModel, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - controlnet_cond: torch.Tensor, - conditioning_scale: float = 1.0, - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, - return_dict: bool = True, - ) -> Union[ControlNetXSOutput, Tuple]: - """ - The [`ControlNetModel`] forward method. - - Args: - base_model (`UNet2DConditionModel`): - The base unet model we want to control. - sample (`torch.FloatTensor`): - The noisy input tensor. - timestep (`Union[torch.Tensor, float, int]`): - The number of timesteps to denoise an input. - encoder_hidden_states (`torch.Tensor`): - The encoder hidden states. - controlnet_cond (`torch.FloatTensor`): - The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. - conditioning_scale (`float`, defaults to `1.0`): - How much the control model affects the base model outputs. - class_labels (`torch.Tensor`, *optional*, defaults to `None`): - Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. - timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): - Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the - timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep - embeddings. - attention_mask (`torch.Tensor`, *optional*, defaults to `None`): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - added_cond_kwargs (`dict`): - Additional conditions for the Stable Diffusion XL UNet. - cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): - A kwargs dictionary that if specified is passed along to the `AttnProcessor`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. - - Returns: - [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`: - If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. - """ - # check channel order - channel_order = self.config.controlnet_conditioning_channel_order - - if channel_order == "rgb": - # in rgb order by default - ... - elif channel_order == "bgr": - controlnet_cond = torch.flip(controlnet_cond, dims=[1]) - else: - raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") - - # scale control strength - n_connections = len(self.down_zero_convs_out) + 1 + len(self.up_zero_convs_out) - scale_list = torch.full((n_connections,), conditioning_scale) - - # prepare attention_mask - if attention_mask is not None: - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = base_model.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) - - if self.config.learn_embedding: - ctrl_temb = self.control_model.time_embedding(t_emb, timestep_cond) - base_temb = base_model.time_embedding(t_emb, timestep_cond) - interpolation_param = self.config.time_embedding_mix**0.3 - - temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) - else: - temb = base_model.time_embedding(t_emb) - - # added time & text embeddings - aug_emb = None - - if base_model.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if base_model.config.class_embed_type == "timestep": - class_labels = base_model.time_proj(class_labels) - - class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype) - temb = temb + class_emb - - if base_model.config.addition_embed_type is not None: - if base_model.config.addition_embed_type == "text": - aug_emb = base_model.add_embedding(encoder_hidden_states) - elif base_model.config.addition_embed_type == "text_image": - raise NotImplementedError() - elif base_model.config.addition_embed_type == "text_time": - # SDXL - style - if "text_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" - ) - text_embeds = added_cond_kwargs.get("text_embeds") - if "time_ids" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" - ) - time_ids = added_cond_kwargs.get("time_ids") - time_embeds = base_model.add_time_proj(time_ids.flatten()) - time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) - add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) - add_embeds = add_embeds.to(temb.dtype) - aug_emb = base_model.add_embedding(add_embeds) - elif base_model.config.addition_embed_type == "image": - raise NotImplementedError() - elif base_model.config.addition_embed_type == "image_hint": - raise NotImplementedError() - - temb = temb + aug_emb if aug_emb is not None else temb - - # text embeddings - cemb = encoder_hidden_states - - # Preparation - guided_hint = self.controlnet_cond_embedding(controlnet_cond) - - h_ctrl = h_base = sample - hs_base, hs_ctrl = [], [] - it_down_convs_in, it_down_convs_out, it_dec_convs_in, it_up_convs_out = map( - iter, (self.down_zero_convs_in, self.down_zero_convs_out, self.up_zero_convs_in, self.up_zero_convs_out) - ) - scales = iter(scale_list) - - base_down_subblocks = to_sub_blocks(base_model.down_blocks) - ctrl_down_subblocks = to_sub_blocks(self.control_model.down_blocks) - base_mid_subblocks = to_sub_blocks([base_model.mid_block]) - ctrl_mid_subblocks = to_sub_blocks([self.control_model.mid_block]) - base_up_subblocks = to_sub_blocks(base_model.up_blocks) - - # Cross Control - # 0 - conv in - h_base = base_model.conv_in(h_base) - h_ctrl = self.control_model.conv_in(h_ctrl) - if guided_hint is not None: - h_ctrl += guided_hint - h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base - - hs_base.append(h_base) - hs_ctrl.append(h_ctrl) - - # 1 - down - for m_base, m_ctrl in zip(base_down_subblocks, ctrl_down_subblocks): - h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl - h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock - h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock - h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base - hs_base.append(h_base) - hs_ctrl.append(h_ctrl) - - # 2 - mid - h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl - for m_base, m_ctrl in zip(base_mid_subblocks, ctrl_mid_subblocks): - h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock - h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock - h_base = h_base + self.middle_block_out(h_ctrl) * next(scales) # D - add ctrl -> base - - # 3 - up - for i, m_base in enumerate(base_up_subblocks): - h_base = h_base + next(it_up_convs_out)(hs_ctrl.pop()) * next(scales) # add info from ctrl encoder - h_base = torch.cat([h_base, hs_base.pop()], dim=1) # concat info from base encoder+ctrl encoder - h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) - - h_base = base_model.conv_norm_out(h_base) - h_base = base_model.conv_act(h_base) - h_base = base_model.conv_out(h_base) - - if not return_dict: - return h_base - - return ControlNetXSOutput(sample=h_base) - - def _make_zero_conv(self, in_channels, out_channels=None): - # keep running track of channels sizes - self.in_channels = in_channels - self.out_channels = out_channels or in_channels - - return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) - - @torch.no_grad() - def _check_if_vae_compatible(self, vae: AutoencoderKL): - condition_downscale_factor = 2 ** (len(self.config.conditioning_embedding_out_channels) - 1) - vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1) - compatible = condition_downscale_factor == vae_downscale_factor - return compatible, condition_downscale_factor, vae_downscale_factor - - -class SubBlock(nn.ModuleList): - """A SubBlock is the largest piece of either base or control model, that is executed independently of the other model respectively. - Before each subblock, information is concatted from base to control. And after each subblock, information is added from control to base. - """ - - def __init__(self, ms, *args, **kwargs): - if not is_iterable(ms): - ms = [ms] - super().__init__(ms, *args, **kwargs) - - def forward( - self, - x: torch.Tensor, - temb: torch.Tensor, - cemb: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - ): - """Iterate through children and pass correct information to each.""" - for m in self: - if isinstance(m, ResnetBlock2D): - x = m(x, temb) - elif isinstance(m, Transformer2DModel): - x = m(x, cemb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs).sample - elif isinstance(m, Downsample2D): - x = m(x) - elif isinstance(m, Upsample2D): - x = m(x) - else: - raise ValueError( - f"Type of m is {type(m)} but should be `ResnetBlock2D`, `Transformer2DModel`, `Downsample2D` or `Upsample2D`" - ) - - return x - - -def adjust_time_dims(unet: UNet2DConditionModel, in_dim: int, out_dim: int): - unet.time_embedding.linear_1 = nn.Linear(in_dim, out_dim) - - -def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no, resnet_idx, by): - """Increase channels sizes to allow for additional concatted information from base model""" - r = unet.down_blocks[block_no].resnets[resnet_idx] - old_norm1, old_conv1 = r.norm1, r.conv1 - # norm - norm_args = "num_groups num_channels eps affine".split(" ") - for a in norm_args: - assert hasattr(old_norm1, a) - norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args} - norm_kwargs["num_channels"] += by # surgery done here - # conv1 - conv1_args = [ - "in_channels", - "out_channels", - "kernel_size", - "stride", - "padding", - "dilation", - "groups", - "bias", - "padding_mode", - ] - if not USE_PEFT_BACKEND: - conv1_args.append("lora_layer") - - for a in conv1_args: - assert hasattr(old_conv1, a) - - conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args} - conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor. - conv1_kwargs["in_channels"] += by # surgery done here - # conv_shortcut - # as we changed the input size of the block, the input and output sizes are likely different, - # therefore we need a conv_shortcut (simply adding won't work) - conv_shortcut_args_kwargs = { - "in_channels": conv1_kwargs["in_channels"], - "out_channels": conv1_kwargs["out_channels"], - # default arguments from resnet.__init__ - "kernel_size": 1, - "stride": 1, - "padding": 0, - "bias": True, - } - # swap old with new modules - unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs) - unet.down_blocks[block_no].resnets[resnet_idx].conv1 = ( - nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs) - ) - unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = ( - nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs) - ) - unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by # surgery done here - - -def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by): - """Increase channels sizes to allow for additional concatted information from base model""" - old_down = unet.down_blocks[block_no].downsamplers[0].conv - - args = [ - "in_channels", - "out_channels", - "kernel_size", - "stride", - "padding", - "dilation", - "groups", - "bias", - "padding_mode", - ] - if not USE_PEFT_BACKEND: - args.append("lora_layer") - - for a in args: - assert hasattr(old_down, a) - kwargs = {a: getattr(old_down, a) for a in args} - kwargs["bias"] = "bias" in kwargs # as param, bias is a boolean, but as attr, it's a tensor. - kwargs["in_channels"] += by # surgery done here - # swap old with new modules - unet.down_blocks[block_no].downsamplers[0].conv = ( - nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs) - ) - unet.down_blocks[block_no].downsamplers[0].channels += by # surgery done here - - -def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by): - """Increase channels sizes to allow for additional concatted information from base model""" - m = unet.mid_block.resnets[0] - old_norm1, old_conv1 = m.norm1, m.conv1 - # norm - norm_args = "num_groups num_channels eps affine".split(" ") - for a in norm_args: - assert hasattr(old_norm1, a) - norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args} - norm_kwargs["num_channels"] += by # surgery done here - conv1_args = [ - "in_channels", - "out_channels", - "kernel_size", - "stride", - "padding", - "dilation", - "groups", - "bias", - "padding_mode", - ] - if not USE_PEFT_BACKEND: - conv1_args.append("lora_layer") - - conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args} - conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor. - conv1_kwargs["in_channels"] += by # surgery done here - # conv_shortcut - # as we changed the input size of the block, the input and output sizes are likely different, - # therefore we need a conv_shortcut (simply adding won't work) - conv_shortcut_args_kwargs = { - "in_channels": conv1_kwargs["in_channels"], - "out_channels": conv1_kwargs["out_channels"], - # default arguments from resnet.__init__ - "kernel_size": 1, - "stride": 1, - "padding": 0, - "bias": True, - } - # swap old with new modules - unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs) - unet.mid_block.resnets[0].conv1 = ( - nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs) - ) - unet.mid_block.resnets[0].conv_shortcut = ( - nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs) - ) - unet.mid_block.resnets[0].in_channels += by # surgery done here - - -def adjust_group_norms(unet: UNet2DConditionModel, max_num_group: int = 32): - def find_denominator(number, start): - if start >= number: - return number - while start != 0: - residual = number % start - if residual == 0: - return start - start -= 1 - - for block in [*unet.down_blocks, unet.mid_block]: - # resnets - for r in block.resnets: - if r.norm1.num_groups < max_num_group: - r.norm1.num_groups = find_denominator(r.norm1.num_channels, start=max_num_group) - - if r.norm2.num_groups < max_num_group: - r.norm2.num_groups = find_denominator(r.norm2.num_channels, start=max_num_group) - - # transformers - if hasattr(block, "attentions"): - for a in block.attentions: - if a.norm.num_groups < max_num_group: - a.norm.num_groups = find_denominator(a.norm.num_channels, start=max_num_group) - - -def is_iterable(o): - if isinstance(o, str): - return False - try: - iter(o) - return True - except TypeError: - return False - - -def to_sub_blocks(blocks): - if not is_iterable(blocks): - blocks = [blocks] - - sub_blocks = [] - - for b in blocks: - if hasattr(b, "resnets"): - if hasattr(b, "attentions") and b.attentions is not None: - for r, a in zip(b.resnets, b.attentions): - sub_blocks.append([r, a]) - - num_resnets = len(b.resnets) - num_attns = len(b.attentions) - - if num_resnets > num_attns: - # we can have more resnets than attentions, so add each resnet as separate subblock - for i in range(num_attns, num_resnets): - sub_blocks.append([b.resnets[i]]) - else: - for r in b.resnets: - sub_blocks.append([r]) - - # upsamplers are part of the same subblock - if hasattr(b, "upsamplers") and b.upsamplers is not None: - for u in b.upsamplers: - sub_blocks[-1].extend([u]) - - # downsamplers are own subblock - if hasattr(b, "downsamplers") and b.downsamplers is not None: - for d in b.downsamplers: - sub_blocks.append([d]) - - return list(map(SubBlock, sub_blocks)) - - -def zero_module(module): - for p in module.parameters(): - nn.init.zeros_(p) - return module diff --git a/examples/research_projects/controlnetxs/infer_sd_controlnetxs.py b/examples/research_projects/controlnetxs/infer_sd_controlnetxs.py deleted file mode 100644 index 722b282a3251..000000000000 --- a/examples/research_projects/controlnetxs/infer_sd_controlnetxs.py +++ /dev/null @@ -1,58 +0,0 @@ -# !pip install opencv-python transformers accelerate -import argparse - -import cv2 -import numpy as np -import torch -from controlnetxs import ControlNetXSModel -from PIL import Image -from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline - -from diffusers.utils import load_image - - -parser = argparse.ArgumentParser() -parser.add_argument( - "--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" -) -parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches") -parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7) -parser.add_argument( - "--image_path", - type=str, - default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png", -) -parser.add_argument("--num_inference_steps", type=int, default=50) - -args = parser.parse_args() - -prompt = args.prompt -negative_prompt = args.negative_prompt -# download an image -image = load_image(args.image_path) - -# initialize the models and pipeline -controlnet_conditioning_scale = args.controlnet_conditioning_scale -controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16) -pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16 -) -pipe.enable_model_cpu_offload() - -# get canny image -image = np.array(image) -image = cv2.Canny(image, 100, 200) -image = image[:, :, None] -image = np.concatenate([image, image, image], axis=2) -canny_image = Image.fromarray(image) - -num_inference_steps = args.num_inference_steps - -# generate image -image = pipe( - prompt, - controlnet_conditioning_scale=controlnet_conditioning_scale, - image=canny_image, - num_inference_steps=num_inference_steps, -).images[0] -image.save("cnxs_sd.canny.png") diff --git a/examples/research_projects/controlnetxs/infer_sdxl_controlnetxs.py b/examples/research_projects/controlnetxs/infer_sdxl_controlnetxs.py deleted file mode 100644 index e5b8cfd88223..000000000000 --- a/examples/research_projects/controlnetxs/infer_sdxl_controlnetxs.py +++ /dev/null @@ -1,57 +0,0 @@ -# !pip install opencv-python transformers accelerate -import argparse - -import cv2 -import numpy as np -import torch -from controlnetxs import ControlNetXSModel -from PIL import Image -from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline - -from diffusers.utils import load_image - - -parser = argparse.ArgumentParser() -parser.add_argument( - "--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" -) -parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches") -parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7) -parser.add_argument( - "--image_path", - type=str, - default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png", -) -parser.add_argument("--num_inference_steps", type=int, default=50) - -args = parser.parse_args() - -prompt = args.prompt -negative_prompt = args.negative_prompt -# download an image -image = load_image(args.image_path) -# initialize the models and pipeline -controlnet_conditioning_scale = args.controlnet_conditioning_scale -controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16) -pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 -) -pipe.enable_model_cpu_offload() - -# get canny image -image = np.array(image) -image = cv2.Canny(image, 100, 200) -image = image[:, :, None] -image = np.concatenate([image, image, image], axis=2) -canny_image = Image.fromarray(image) - -num_inference_steps = args.num_inference_steps - -# generate image -image = pipe( - prompt, - controlnet_conditioning_scale=controlnet_conditioning_scale, - image=canny_image, - num_inference_steps=num_inference_steps, -).images[0] -image.save("cnxs_sdxl.canny.png") diff --git a/examples/research_projects/controlnetxs/pipeline_controlnet_xs.py b/examples/research_projects/controlnetxs/pipeline_controlnet_xs.py deleted file mode 100644 index 32646c7c7715..000000000000 --- a/examples/research_projects/controlnetxs/pipeline_controlnet_xs.py +++ /dev/null @@ -1,901 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, Callable, Dict, List, Optional, Union - -import numpy as np -import PIL.Image -import torch -import torch.nn.functional as F -from controlnetxs import ControlNetXSModel -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer - -from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import ( - USE_PEFT_BACKEND, - deprecate, - logging, - scale_lora_layers, - unscale_lora_layers, -) -from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class StableDiffusionControlNetXSPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin -): - r""" - Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods - implemented for all pipelines (downloading, saving, running on a particular device, etc.). - - The pipeline also inherits the following loading methods: - - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. - text_encoder ([`~transformers.CLIPTextModel`]): - Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). - tokenizer ([`~transformers.CLIPTokenizer`]): - A `CLIPTokenizer` to tokenize text. - unet ([`UNet2DConditionModel`]): - A `UNet2DConditionModel` to denoise the encoded image latents. - controlnet ([`ControlNetXSModel`]): - Provides additional conditioning to the `unet` during the denoising process. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details - about a model's potential harms. - feature_extractor ([`~transformers.CLIPImageProcessor`]): - A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. - """ - - model_cpu_offload_seq = "text_encoder->unet->vae>controlnet" - _optional_components = ["safety_checker", "feature_extractor"] - _exclude_from_cpu_offload = ["safety_checker"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - controlnet: ControlNetXSModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible( - vae - ) - if not vae_compatible: - raise ValueError( - f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`." - ) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) - self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False - ) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.vae.enable_tiling() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - **kwargs, - ): - deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." - deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) - - prompt_embeds_tuple = self.encode_prompt( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=lora_scale, - **kwargs, - ) - - # concatenate for backwards comp - prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt - def encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - lora_scale (`float`, *optional*): - A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) - prompt_embeds = prompt_embeds[0] - else: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True - ) - # Access the `hidden_states` first, that contains a tuple of - # all the hidden states from the encoder layers. Then index into - # the tuple to access the hidden states from the desired layer. - prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] - # We also need to apply the final LayerNorm here to not mess with the - # representations. The `last_hidden_states` that we typically use for - # obtaining the final prompt representations passes through the LayerNorm - # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) - - if self.text_encoder is not None: - prompt_embeds_dtype = self.text_encoder.dtype - elif self.unet is not None: - prompt_embeds_dtype = self.unet.dtype - else: - prompt_embeds_dtype = prompt_embeds.dtype - - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - return prompt_embeds, negative_prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" - deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) - - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - image, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - controlnet_conditioning_scale=1.0, - control_guidance_start=0.0, - control_guidance_end=1.0, - ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - # Check `image` - is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule - ) - if ( - isinstance(self.controlnet, ControlNetXSModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetXSModel) - ): - self.check_image(image, prompt, prompt_embeds) - else: - assert False - - # Check `controlnet_conditioning_scale` - if ( - isinstance(self.controlnet, ControlNetXSModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetXSModel) - ): - if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - else: - assert False - - start, end = control_guidance_start, control_guidance_end - if start >= end: - raise ValueError( - f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." - ) - if start < 0.0: - raise ValueError(f"control guidance start: {start} can't be smaller than 0.") - if end > 1.0: - raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") - - def check_image(self, image, prompt, prompt_embeds): - image_is_pil = isinstance(image, PIL.Image.Image) - image_is_tensor = isinstance(image, torch.Tensor) - image_is_np = isinstance(image, np.ndarray) - image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) - image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) - image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) - - if ( - not image_is_pil - and not image_is_tensor - and not image_is_np - and not image_is_pil_list - and not image_is_tensor_list - and not image_is_np_list - ): - raise TypeError( - f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" - ) - - if image_is_pil: - image_batch_size = 1 - else: - image_batch_size = len(image) - - if prompt is not None and isinstance(prompt, str): - prompt_batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - prompt_batch_size = len(prompt) - elif prompt_embeds is not None: - prompt_batch_size = prompt_embeds.shape[0] - - if image_batch_size != 1 and image_batch_size != prompt_batch_size: - raise ValueError( - f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" - ) - - def prepare_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - ): - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance: - image = torch.cat([image] * 2) - - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): - r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. - - The suffixes after the scaling factors represent the stages where they are being applied. - - Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values - that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. - - Args: - s1 (`float`): - Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - s2 (`float`): - Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. - b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. - """ - if not hasattr(self, "unet"): - raise ValueError("The pipeline must have `unet` for using FreeU.") - self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu - def disable_freeu(self): - """Disables the FreeU mechanism if enabled.""" - self.unet.disable_freeu() - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]] = None, - image: PipelineImageInput = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: Union[float, List[float]] = 1.0, - control_guidance_start: float = 0.0, - control_guidance_end: float = 1.0, - clip_skip: Optional[int] = None, - ): - r""" - The call function to the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, - `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): - The ControlNet input condition to provide guidance to the `unet` for generation. If the type is - specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be - accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height - and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in - `init`, images must be passed as a list such that each element of the list can be correctly batched for - input to a single ControlNet. - height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies - to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): - The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set - the corresponding scale as a list. - control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): - The percentage of total steps at which the ControlNet starts applying. - control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): - The percentage of total steps at which the ControlNet stops applying. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - - Examples: - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, - otherwise a `tuple` is returned where the first element is a list with the generated images and the - second element is a list of `bool`s indicating whether the corresponding generated image contains - "not-safe-for-work" (nsfw) content. - """ - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - image, - callback_steps, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, - ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - # 4. Prepare image - if isinstance(controlnet, ControlNetXSModel): - image = self.prepare_image( - image=image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - ) - height, width = image.shape[-2:] - else: - assert False - - # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 6. Prepare latent variables - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - is_unet_compiled = is_compiled_module(self.unet) - is_controlnet_compiled = is_compiled_module(self.controlnet) - is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # Relevant thread: - # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: - torch._inductor.cudagraph_mark_step_begin() - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - dont_control = ( - i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end - ) - if dont_control: - noise_pred = self.unet( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - return_dict=True, - ).sample - else: - noise_pred = self.controlnet( - base_model=self.unet, - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - controlnet_cond=image, - conditioning_scale=controlnet_conditioning_scale, - cross_attention_kwargs=cross_attention_kwargs, - return_dict=True, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) - - # If we do sequential model offloading, let's offload unet and controlnet - # manually for max memory savings - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.unet.to("cpu") - self.controlnet.to("cpu") - torch.cuda.empty_cache() - - if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ - 0 - ] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - else: - image = latents - has_nsfw_concept = None - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py b/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py deleted file mode 100644 index b9b390f1c00c..000000000000 --- a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py +++ /dev/null @@ -1,1073 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import PIL.Image -import torch -import torch.nn.functional as F -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer - -from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin -from diffusers.models import AutoencoderKL, ControlNetXSModel, UNet2DConditionModel -from diffusers.models.attention_processor import ( - AttnProcessor2_0, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, - XFormersAttnProcessor, -) -from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput -from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import ( - USE_PEFT_BACKEND, - logging, - scale_lora_layers, - unscale_lora_layers, -) -from diffusers.utils.import_utils import is_invisible_watermark_available -from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor - - -if is_invisible_watermark_available(): - from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class StableDiffusionXLControlNetXSPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin -): - r""" - Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet-XS guidance. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods - implemented for all pipelines (downloading, saving, running on a particular device, etc.). - - The pipeline also inherits the following loading methods: - - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. - text_encoder ([`~transformers.CLIPTextModel`]): - Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). - text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): - Second frozen text-encoder - ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). - tokenizer ([`~transformers.CLIPTokenizer`]): - A `CLIPTokenizer` to tokenize text. - tokenizer_2 ([`~transformers.CLIPTokenizer`]): - A `CLIPTokenizer` to tokenize text. - unet ([`UNet2DConditionModel`]): - A `UNet2DConditionModel` to denoise the encoded image latents. - controlnet ([`ControlNetXSModel`]: - Provides additional conditioning to the `unet` during the denoising process. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): - Whether the negative prompt embeddings should always be set to 0. Also see the config of - `stabilityai/stable-diffusion-xl-base-1-0`. - add_watermarker (`bool`, *optional*): - Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to - watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no - watermarker is used. - """ - - # leave controlnet out on purpose because it iterates with unet - model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae->controlnet" - _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - text_encoder_2: CLIPTextModelWithProjection, - tokenizer: CLIPTokenizer, - tokenizer_2: CLIPTokenizer, - unet: UNet2DConditionModel, - controlnet: ControlNetXSModel, - scheduler: KarrasDiffusionSchedulers, - force_zeros_for_empty_prompt: bool = True, - add_watermarker: Optional[bool] = None, - ): - super().__init__() - - vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible( - vae - ) - if not vae_compatible: - raise ValueError( - f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`." - ) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) - self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False - ) - add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() - - if add_watermarker: - self.watermark = StableDiffusionXLWatermarker() - else: - self.watermark = None - - self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.vae.enable_tiling() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt - def encode_prompt( - self, - prompt: str, - prompt_2: Optional[str] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, - negative_prompt: Optional[str] = None, - negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in both text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] - text_encoders = ( - [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] - ) - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # textual inversion: process multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, tokenizer) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] - else: - # "2" because SDXL always indexes from the penultimate layer. - prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - - # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 - ) - - uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] - - negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), - output_hidden_states=True, - ) - # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] - - negative_prompt_embeds_list.append(negative_prompt_embeds) - - negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - - if self.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - else: - prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - if self.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - - if self.text_encoder is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - prompt_2, - image, - callback_steps, - negative_prompt=None, - negative_prompt_2=None, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, - controlnet_conditioning_scale=1.0, - control_guidance_start=0.0, - control_guidance_end=1.0, - ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt_2 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - elif negative_prompt_2 is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) - - # Check `image` - is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule - ) - if ( - isinstance(self.controlnet, ControlNetXSModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetXSModel) - ): - self.check_image(image, prompt, prompt_embeds) - else: - assert False - - # Check `controlnet_conditioning_scale` - if ( - isinstance(self.controlnet, ControlNetXSModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetXSModel) - ): - if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - else: - assert False - - start, end = control_guidance_start, control_guidance_end - if start >= end: - raise ValueError( - f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." - ) - if start < 0.0: - raise ValueError(f"control guidance start: {start} can't be smaller than 0.") - if end > 1.0: - raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") - - # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image - def check_image(self, image, prompt, prompt_embeds): - image_is_pil = isinstance(image, PIL.Image.Image) - image_is_tensor = isinstance(image, torch.Tensor) - image_is_np = isinstance(image, np.ndarray) - image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) - image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) - image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) - - if ( - not image_is_pil - and not image_is_tensor - and not image_is_np - and not image_is_pil_list - and not image_is_tensor_list - and not image_is_np_list - ): - raise TypeError( - f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" - ) - - if image_is_pil: - image_batch_size = 1 - else: - image_batch_size = len(image) - - if prompt is not None and isinstance(prompt, str): - prompt_batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - prompt_batch_size = len(prompt) - elif prompt_embeds is not None: - prompt_batch_size = prompt_embeds.shape[0] - - if image_batch_size != 1 and image_batch_size != prompt_batch_size: - raise ValueError( - f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" - ) - - def prepare_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - ): - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance: - image = torch.cat([image] * 2) - - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None - ): - add_time_ids = list(original_size + crops_coords_top_left + target_size) - - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae - def upcast_vae(self): - dtype = self.vae.dtype - self.vae.to(dtype=torch.float32) - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(dtype) - self.vae.decoder.conv_in.to(dtype) - self.vae.decoder.mid_block.to(dtype) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): - r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. - - The suffixes after the scaling factors represent the stages where they are being applied. - - Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values - that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. - - Args: - s1 (`float`): - Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - s2 (`float`): - Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. - b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. - """ - if not hasattr(self, "unet"): - raise ValueError("The pipeline must have `unet` for using FreeU.") - self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu - def disable_freeu(self): - """Disables the FreeU mechanism if enabled.""" - self.unet.disable_freeu() - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - image: PipelineImageInput = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 5.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - negative_prompt_2: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: Union[float, List[float]] = 1.0, - control_guidance_start: float = 0.0, - control_guidance_end: float = 1.0, - original_size: Tuple[int, int] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - target_size: Tuple[int, int] = None, - negative_original_size: Optional[Tuple[int, int]] = None, - negative_crops_coords_top_left: Tuple[int, int] = (0, 0), - negative_target_size: Optional[Tuple[int, int]] = None, - clip_skip: Optional[int] = None, - ): - r""" - The call function to the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in both text-encoders. - image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, - `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): - The ControlNet input condition to provide guidance to the `unet` for generation. If the type is - specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be - accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height - and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in - `init`, images must be passed as a list such that each element of the list can be correctly batched for - input to a single ControlNet. - height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The height in pixels of the generated image. Anything below 512 pixels won't work well for - [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) - and checkpoints that are not specifically fine-tuned on low resolutions. - width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The width in pixels of the generated image. Anything below 512 pixels won't work well for - [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) - and checkpoints that are not specifically fine-tuned on low resolutions. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 5.0): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` - and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies - to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, pooled text embeddings are generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt - weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): - The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original `unet`. - control_guidance_start (`float`, *optional*, defaults to 0.0): - The percentage of total steps at which the ControlNet starts applying. - control_guidance_end (`float`, *optional*, defaults to 1.0): - The percentage of total steps at which the ControlNet stops applying. - original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): - If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. - `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as - explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): - `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position - `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting - `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): - For most cases, `target_size` should be set to the desired height and width of the generated image. If - not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in - section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): - To negatively condition the generation process based on a specific image resolution. Part of SDXL's - micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more - information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): - To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's - micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more - information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): - To negatively condition the generation process based on a target image resolution. It should be as same - as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more - information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - - Examples: - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is - returned, otherwise a `tuple` is returned containing the output images. - """ - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - prompt_2, - image, - callback_steps, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.encode_prompt( - prompt, - prompt_2, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - negative_prompt_2, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, - ) - - # 4. Prepare image - if isinstance(controlnet, ControlNetXSModel): - image = self.prepare_image( - image=image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - ) - height, width = image.shape[-2:] - else: - assert False - - # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 6. Prepare latent variables - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7.1 Prepare added time ids & embeddings - if isinstance(image, list): - original_size = original_size or image[0].shape[-2:] - else: - original_size = original_size or image.shape[-2:] - target_size = target_size or (height, width) - - add_text_embeds = pooled_prompt_embeds - if self.text_encoder_2 is None: - text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) - else: - text_encoder_projection_dim = self.text_encoder_2.config.projection_dim - - add_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - dtype=prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - - if negative_original_size is not None and negative_target_size is not None: - negative_add_time_ids = self._get_add_time_ids( - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype=prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - else: - negative_add_time_ids = add_time_ids - - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) - - prompt_embeds = prompt_embeds.to(device) - add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - is_unet_compiled = is_compiled_module(self.unet) - is_controlnet_compiled = is_compiled_module(self.controlnet) - is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # Relevant thread: - # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: - torch._inductor.cudagraph_mark_step_begin() - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - - # predict the noise residual - dont_control = ( - i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end - ) - if dont_control: - noise_pred = self.unet( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=True, - ).sample - else: - noise_pred = self.controlnet( - base_model=self.unet, - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - controlnet_cond=image, - conditioning_scale=controlnet_conditioning_scale, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=True, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) - - if not output_type == "latent": - # make sure the VAE is in float32 mode, as it overflows in float16 - needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast - - if needs_upcasting: - self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - - # cast back to fp16 if needed - if needs_upcasting: - self.vae.to(dtype=torch.float16) - else: - image = latents - - if not output_type == "latent": - # apply watermark if available - if self.watermark is not None: - image = self.watermark.apply_watermark(image) - - image = self.image_processor.postprocess(image, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image,) - - return StableDiffusionXLPipelineOutput(images=image) From 4185e86ff66ab6108f0021b91f7828629287aad1 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 29 Feb 2024 19:54:41 +0100 Subject: [PATCH 45/75] Make style, quality, fix-copies --- .../pipelines/controlnet_xs/pipeline_controlnet_xs.py | 4 ++-- tests/pipelines/controlnet_xs/test_controlnetxs.py | 8 ++++++-- tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py | 8 ++++++-- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index f3f560ac30e4..c05813198ea5 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -357,7 +357,7 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: - # textual inversion: procecss multi-vector tokens if necessary + # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) @@ -439,7 +439,7 @@ def encode_prompt( else: uncond_tokens = negative_prompt - # textual inversion: procecss multi-vector tokens if necessary + # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index f8fba1f96c8e..1dbeece8b01e 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -263,7 +263,9 @@ def test_save_load_local(self, expected_max_difference=5e-4): pipe.get_controlnet_addon().save_pretrained(tmpdir_addon, safe_serialization=False) addon_loaded = ControlNetXSAddon.from_pretrained(tmpdir_addon) - pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, controlnet_addon=addon_loaded) + pipe_loaded = self.pipeline_class.from_pretrained( + base_path=tmpdir_components, controlnet_addon=addon_loaded + ) for component in pipe_loaded.components.values(): if hasattr(component, "set_default_attn_processor"): @@ -301,7 +303,9 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): pipe.get_controlnet_addon().save_pretrained(tmpdir_addon, safe_serialization=False) addon_loaded = ControlNetXSAddon.from_pretrained(tmpdir_addon) - pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, controlnet_addon=addon_loaded) + pipe_loaded = self.pipeline_class.from_pretrained( + base_path=tmpdir_components, controlnet_addon=addon_loaded + ) for component in pipe_loaded.components.values(): if hasattr(component, "set_default_attn_processor"): diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index e854ede94259..b2bea4c6810d 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -332,7 +332,9 @@ def test_save_load_local(self, expected_max_difference=5e-4): pipe.get_controlnet_addon().save_pretrained(tmpdir_addon, safe_serialization=False) addon_loaded = ControlNetXSAddon.from_pretrained(tmpdir_addon) - pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, controlnet_addon=addon_loaded) + pipe_loaded = self.pipeline_class.from_pretrained( + base_path=tmpdir_components, controlnet_addon=addon_loaded + ) for component in pipe_loaded.components.values(): if hasattr(component, "set_default_attn_processor"): @@ -391,7 +393,9 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): pipe.get_controlnet_addon().save_pretrained(tmpdir_addon, safe_serialization=False) addon_loaded = ControlNetXSAddon.from_pretrained(tmpdir_addon) - pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, controlnet_addon=addon_loaded) + pipe_loaded = self.pipeline_class.from_pretrained( + base_path=tmpdir_components, controlnet_addon=addon_loaded + ) for component in pipe_loaded.components.values(): if hasattr(component, "set_default_attn_processor"): From 49049cbc8e281cd222ad9445ecc28795a6c51726 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 29 Feb 2024 21:03:46 +0100 Subject: [PATCH 46/75] Small fixes - deleted ControlNetXSModel.init_original - added time_embedding_mix to StableDiffusionControlNetXSPipeline .from_pretrained / StableDiffusionXLControlNetXSPipeline.from_pretrained - fixed copy hints --- src/diffusers/models/controlnet_xs.py | 42 ------------------- .../controlnet_xs/pipeline_controlnet_xs.py | 19 +++++---- .../pipeline_controlnet_xs_sd_xl.py | 19 +++++---- 3 files changed, 20 insertions(+), 60 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 5176044187da..890b9995421d 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union @@ -471,47 +470,6 @@ class ControlNetXSModel(nn.Module): Otherwise, both are combined. """ - @classmethod - def init_original(cls, base_model: UNet2DConditionModel, is_sdxl=False): - """ - Create a `ControlNetXSModel` model with the same parameters as in the original paper (https://github.com/vislearn/ControlNet-XS). - - Parameters: - base_model (`UNet2DConditionModel`): - Base UNet model. Needs to be either StableDiffusion or StableDiffusion-XL. - is_sdxl (`bool`, defaults to `False`): - Whether passed `base_model` is a StableDiffusion-XL model. - """ - - def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_attn_heads: int): - """ - Currently, diffusers can only set the dimension of attention heads (see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why). - The original ControlNet-XS model, however, define the number of attention heads. - That's why we compute the dimensions needed to get the correct number of attention heads. - """ - block_out_channels = [int(size_ratio * c) for c in base_model.config.block_out_channels] - dim_attn_heads = [math.ceil(c / num_attn_heads) for c in block_out_channels] - return dim_attn_heads - - if is_sdxl: - time_embedding_mix = 0.95 - controlnet_addon = ControlNetXSAddon.from_unet( - base_model, - learn_time_embedding=True, - size_ratio=0.1, - num_attention_heads=get_dim_attn_heads(base_model, 0.1, 64), - ) - else: - time_embedding_mix = 1.0 - controlnet_addon = ControlNetXSAddon.from_unet( - base_model, - learn_time_embedding=True, - size_ratio=0.0125, - num_attention_heads=get_dim_attn_heads(base_model, 0.0125, 8), - ) - - return cls(base_model=base_model, ctrl_addon=controlnet_addon, time_embedding_mix=time_embedding_mix) - def __init__( self, base_model: UNet2DConditionModel, diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index c05813198ea5..39459e75b989 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -48,7 +48,7 @@ Examples: ```py >>> # !pip install opencv-python transformers accelerate - >>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSModel + >>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSAddon >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch @@ -71,8 +71,9 @@ ... "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 ... ) >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-2-1-base", controlnet_xs_addon=controlnet_xs_addon, torch_dtype=torch.float16 - ... ) + ... "stabilityai/stable-diffusion-2-1-base", controlnet_xs_addon=controlnet_xs_addon, + ... time_embedding_mix=1.0, torch_dtype=torch.float16 + ... ) # paper used time_embedding_mix=1.0 >>> pipe.enable_model_cpu_offload() >>> # get canny image @@ -185,7 +186,7 @@ def __init__( self.register_to_config(requires_safety_checker=requires_safety_checker) @classmethod - def from_pretrained(cls, base_path, controlnet_addon, **kwargs): + def from_pretrained(cls, base_path, controlnet_addon, time_embedding_mix=1.0, **kwargs): """ Instantiates pipeline from a `StableDiffusionPipeline` and a `ControlNetXSAddon`. @@ -211,7 +212,7 @@ def from_pretrained(cls, base_path, controlnet_addon, **kwargs): components = {k: v for k, v in components.items() if k not in ["unet"] + to_ignore} - controlnet = ControlNetXSModel(unet, controlnet_addon) + controlnet = ControlNetXSModel(unet, controlnet_addon, time_embedding_mix) return StableDiffusionControlNetXSPipeline(controlnet=controlnet, **components) def save_pretrained(self, *args, **kwargs): @@ -230,7 +231,7 @@ def get_controlnet_addon(self): """Get the `ControlNetXSAddon` model.""" return self.components["controlnet"].ctrl_addon - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to @@ -238,7 +239,7 @@ def enable_vae_slicing(self): """ self.vae.enable_slicing() - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to @@ -246,7 +247,7 @@ def disable_vae_slicing(self): """ self.vae.disable_slicing() - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_tiling def enable_vae_tiling(self): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to @@ -255,7 +256,7 @@ def enable_vae_tiling(self): """ self.vae.enable_tiling() - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_tiling def disable_vae_tiling(self): r""" Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index 9de9557ab83d..72829e084df7 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -64,7 +64,7 @@ Examples: ```py >>> # !pip install opencv-python transformers accelerate - >>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSModel, AutoencoderKL + >>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSAddon, AutoencoderKL >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch @@ -87,8 +87,9 @@ ... "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16 ... ) >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - ... base_path="stabilityai/stable-diffusion-xl-base-1.0", controlnet_xs_addon=controlnet_xs_addon, torch_dtype=torch.float16 - ... ) + ... base_path="stabilityai/stable-diffusion-xl-base-1.0", controlnet_xs_addon=controlnet_xs_addon, + ... time_embedding_mix=0.95, torch_dtype=torch.float16 + ... ) # paper used time_embedding_mix=0.95 >>> pipe.enable_model_cpu_offload() >>> # get canny image @@ -211,7 +212,7 @@ def __init__( self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) @classmethod - def from_pretrained(cls, base_path, controlnet_addon, **kwargs): + def from_pretrained(cls, base_path, controlnet_addon, time_embedding_mix=1.0, **kwargs): """ Instantiates pipeline from a `StableDiffusionXLPipeline` and a `ControlNetXSAddon`. @@ -237,7 +238,7 @@ def from_pretrained(cls, base_path, controlnet_addon, **kwargs): components = {k: v for k, v in components.items() if k not in ["unet"] + to_ignore} - controlnet = ControlNetXSModel(unet, controlnet_addon) + controlnet = ControlNetXSModel(unet, controlnet_addon, time_embedding_mix) return StableDiffusionXLControlNetXSPipeline(controlnet=controlnet, **components) def save_pretrained(self, *args, **kwargs): @@ -256,7 +257,7 @@ def get_controlnet_addon(self): """Get the `ControlNetXSAddon` model.""" return self.components["controlnet"].ctrl_addon - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to @@ -264,7 +265,7 @@ def enable_vae_slicing(self): """ self.vae.enable_slicing() - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to @@ -272,7 +273,7 @@ def disable_vae_slicing(self): """ self.vae.disable_slicing() - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_tiling def enable_vae_tiling(self): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to @@ -281,7 +282,7 @@ def enable_vae_tiling(self): """ self.vae.enable_tiling() - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_tiling def disable_vae_tiling(self): r""" Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to From 7dd6b05ac22fa33e29f5af6fb4e8ec8c349debe0 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Mon, 11 Mar 2024 16:20:39 +0100 Subject: [PATCH 47/75] checkin May 11 '23 --- src/diffusers/__init__.py | 4 +- src/diffusers/models/__init__.py | 4 +- src/diffusers/models/controlnet_xs.py | 265 +++++++++++++++--- .../controlnet_xs/pipeline_controlnet_xs.py | 71 ++--- .../pipeline_controlnet_xs_sd_xl.py | 44 +-- .../controlnet_xs/test_controlnetxs.py | 140 +-------- .../controlnet_xs/test_controlnetxs_sdxl.py | 167 +---------- 7 files changed, 285 insertions(+), 410 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 81a6f078fc60..235d2d9b1672 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -81,7 +81,6 @@ "ConsistencyDecoderVAE", "ControlNetModel", "ControlNetXSAddon", - "ControlNetXSModel", "I2VGenXLUNet", "Kandinsky3UNet", "ModelMixin", @@ -95,6 +94,7 @@ "UNet2DConditionModel", "UNet2DModel", "UNet3DConditionModel", + "UNetControlNetXSModel", "UNetMotionModel", "UNetSpatioTemporalConditionModel", "UVit2DModel", @@ -472,7 +472,6 @@ ConsistencyDecoderVAE, ControlNetModel, ControlNetXSAddon, - ControlNetXSModel, I2VGenXLUNet, Kandinsky3UNet, ModelMixin, @@ -486,6 +485,7 @@ UNet2DConditionModel, UNet2DModel, UNet3DConditionModel, + UNetControlNetXSModel, UNetMotionModel, UNetSpatioTemporalConditionModel, UVit2DModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index a3c88c642650..9b1e3044982f 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -32,7 +32,7 @@ _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["controlnet"] = ["ControlNetModel"] - _import_structure["controlnet_xs"] = ["ControlNetXSAddon", "ControlNetXSModel"] + _import_structure["controlnet_xs"] = ["ControlNetXSAddon", "UNetControlNetXSModel"] _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] @@ -68,7 +68,7 @@ ConsistencyDecoderVAE, ) from .controlnet import ControlNetModel - from .controlnet_xs import ControlNetXSAddon, ControlNetXSModel + from .controlnet_xs import ControlNetXSAddon, UNetControlNetXSModel from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 890b9995421d..515ac8581aab 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -449,7 +449,7 @@ def _make_zero_conv(self, in_channels, out_channels=None): return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) -class ControlNetXSModel(nn.Module): +class UNetControlNetXSModel(ModelMixin, ConfigMixin): r""" A ControlNet-XS model @@ -470,26 +470,89 @@ class ControlNetXSModel(nn.Module): Otherwise, both are combined. """ + @register_to_config def __init__( self, - base_model: UNet2DConditionModel, - ctrl_addon: ControlNetXSAddon, + # unet configs + conditioning_channels: int = 3, + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + time_embedding_input_dim: int = 320, + time_embedding_dim: int = 1280, time_embedding_mix: float = 1.0, + base_model_channel_sizes: Dict[str, List[Tuple[int]]] = { + "down": [ + (4, 320), + (320, 320), + (320, 320), + (320, 320), + (320, 640), + (640, 640), + (640, 640), + (640, 1280), + (1280, 1280), + ], + "mid": [(1280, 1280)], + "up": [ + (2560, 1280), + (2560, 1280), + (1920, 1280), + (1920, 640), + (1280, 640), + (960, 640), + (960, 320), + (640, 320), + (640, 320), + ], + }, + sample_size: Optional[int] = 96, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), # for addon: (4, 8, 16, 16) + norm_num_groups: Optional[int] = 32, + cross_attention_dim: Union[int, Tuple[int]] = 1024, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, # type Tuple[Tuple] necessary? + num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, + upcast_attention: bool = True, + # controlnet configs + controlnet_conditioning_channel_order: str = "rgb", + conditioning_learn_time_embedding: bool = False, + channels_base: Dict[str, List[Tuple[int]]] = { + "down - out": [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280], + "mid - out": 1280, + "up - in": [1280, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320], + }, + attention_head_dim: Union[int, Tuple[int]] = 4, + max_norm_num_groups: int = 32, ): super().__init__() if time_embedding_mix < 0 or time_embedding_mix > 1: raise ValueError("`time_embedding_mix` needs to be between 0 and 1.") - if time_embedding_mix < 1 and not ctrl_addon.config.learn_time_embedding: + if time_embedding_mix < 1 and not conditioning_learn_time_embedding: raise ValueError( "To use `time_embedding_mix` < 1, initialize `ctrl_addon` with `learn_time_embedding = True`" ) - self.ctrl_addon = ctrl_addon - self.base_model = base_model - self.time_embedding_mix = time_embedding_mix + # Create UNet and decompose it into subblocks, which we then save + base_model = UNet2DConditionModel( + sample_size=sample_size, + down_block_types=down_block_types, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + attention_head_dim=num_attention_heads, + use_linear_projection=True, + upcast_attention=upcast_attention, + time_embedding_dim=time_embedding_dim, + ) - # Decompose blocks of base model into subblocks self.base_down_subblocks = nn.ModuleList() self.base_up_subblocks = nn.ModuleList() @@ -526,20 +589,134 @@ def __init__( for r, a, u in zip(resnets, attentions, upsamplers): self.base_up_subblocks.append(CrossAttnUpSubBlock2D.from_modules(r, a, u)) - @property - def device(self) -> torch.device: - """ - `torch.device`: The device on which the module is (assuming that all the module parameters are on the same - device). - """ - return self.base_model.device + self.control_addon = ControlNetXSAddon( + conditioning_channels=conditioning_channels, + conditioning_channel_order=controlnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + time_embedding_input_dim=time_embedding_input_dim, + time_embedding_dim=time_embedding_dim, + learn_time_embedding=conditioning_learn_time_embedding, + channels_base=channels_base, + attention_head_dim=attention_head_dim, + block_out_channels=block_out_channels, + cross_attention_dim=cross_attention_dim, + down_block_types=down_block_types, + sample_size=sample_size, + transformer_layers_per_block=transformer_layers_per_block, + upcast_attention=upcast_attention, + max_norm_num_groups=max_norm_num_groups + ) - @property - def dtype(self) -> torch.dtype: - """ - `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). - """ - return self.base_model.dtype + self.time_embedding_mix = time_embedding_mix + + # todo umer + @classmethod + def from_unet2d( + cls, + unet: UNet2DConditionModel, + controlnet: ControlNetXSAddon, + load_weights: bool = True, + ): + # analogous to diffusers.models.unets.unet_motion_model.UNetMotionModel.from_unet2d + config = unet.config + config["_class_name"] = cls.__name__ + + down_blocks = [] + for down_blocks_type in config["down_block_types"]: + if "CrossAttn" in down_blocks_type: + down_blocks.append("CrossAttnDownBlockMotion") + else: + down_blocks.append("DownBlockMotion") + config["down_block_types"] = down_blocks + + up_blocks = [] + for down_blocks_type in config["up_block_types"]: + if "CrossAttn" in down_blocks_type: + up_blocks.append("CrossAttnUpBlockMotion") + else: + up_blocks.append("UpBlockMotion") + + config["up_block_types"] = up_blocks + + if has_motion_adapter: + config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"] + config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"] + config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"] + + # For PIA UNets we need to set the number input channels to 9 + if motion_adapter.config["conv_in_channels"]: + config["in_channels"] = motion_adapter.config["conv_in_channels"] + + # Need this for backwards compatibility with UNet2DConditionModel checkpoints + if not config.get("num_attention_heads"): + config["num_attention_heads"] = config["attention_head_dim"] + + model = cls.from_config(config) + + if not load_weights: + return model + + # Logic for loading PIA UNets which allow the first 4 channels to be any UNet2DConditionModel conv_in weight + # while the last 5 channels must be PIA conv_in weights. + if has_motion_adapter and motion_adapter.config["conv_in_channels"]: + model.conv_in = motion_adapter.conv_in + updated_conv_in_weight = torch.cat( + [unet.conv_in.weight, motion_adapter.conv_in.weight[:, 4:, :, :]], dim=1 + ) + model.conv_in.load_state_dict({"weight": updated_conv_in_weight, "bias": unet.conv_in.bias}) + else: + model.conv_in.load_state_dict(unet.conv_in.state_dict()) + + model.time_proj.load_state_dict(unet.time_proj.state_dict()) + model.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + for i, down_block in enumerate(unet.down_blocks): + model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict()) + if hasattr(model.down_blocks[i], "attentions"): + model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict()) + if model.down_blocks[i].downsamplers: + model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict()) + + for i, up_block in enumerate(unet.up_blocks): + model.up_blocks[i].resnets.load_state_dict(up_block.resnets.state_dict()) + if hasattr(model.up_blocks[i], "attentions"): + model.up_blocks[i].attentions.load_state_dict(up_block.attentions.state_dict()) + if model.up_blocks[i].upsamplers: + model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict()) + + model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict()) + model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict()) + + if unet.conv_norm_out is not None: + model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict()) + if unet.conv_act is not None: + model.conv_act.load_state_dict(unet.conv_act.state_dict()) + model.conv_out.load_state_dict(unet.conv_out.state_dict()) + + if has_motion_adapter: + model.load_motion_modules(motion_adapter) + + # ensure that the Motion UNet is the same dtype as the UNet2DConditionModel + model.to(unet.dtype) + + return model + + + # todo umer + def load_controlnet_addon(self, controlnet: ControlNetXSAddon) -> None: + pass + + # todo umer + def save_controlnet_addon( + self, + save_directory: str, + is_main_process: bool = True, + safe_serialization: bool = True, + variant: Optional[str] = None, + push_to_hub: bool = False, + **kwargs, + ) -> None: + pass @torch.no_grad() def _check_if_vae_compatible(self, vae: AutoencoderKL): @@ -602,19 +779,6 @@ def forward( tuple is returned where the first element is the sample tensor. """ - if not do_control: - return self.base_model( - sample=sample, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - class_labels=class_labels, - timestep_cond=timestep_cond, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=return_dict, - ) - # check channel order if self.ctrl_addon.config.conditioning_channel_order == "bgr": controlnet_cond = torch.flip(controlnet_cond, dims=[1]) @@ -715,6 +879,38 @@ def forward( mid_zero_convs_c2b = self.ctrl_addon.mid_zero_convs_c2b up_zero_convs_c2b = self.ctrl_addon.up_zero_convs_c2b + if not do_control: + # Run the base model without control + + # 1 - conv in & down + h_base = self.base_model.conv_in(h_base) + hs_base.append(h_base) + + for b in base_down_subblocks: + if isinstance(b, CrossAttnSubBlock2D): + additional_params = [temb, cemb, attention_mask, cross_attention_kwargs] + else: + additional_params = [] + h_base = b(h_base, *additional_params) + hs_base.append(h_base) + + # 2 - mid + h_base = self.base_model.mid_block(h_base, temb, cemb, attention_mask, cross_attention_kwargs) + + # 3 - up + for b, skip_b in zip(self.base_up_subblocks, reversed(hs_base)): + h_base = torch.cat([h_base, skip_b], dim=1) # concat info from base encoder + h_base = b(h_base, temb, cemb, attention_mask, cross_attention_kwargs) + + h_base = self.base_model.conv_norm_out(h_base) + h_base = self.base_model.conv_act(h_base) + h_base = self.base_model.conv_out(h_base) + + if not return_dict: + return h_base + + return ControlNetXSOutput(sample=h_base) + # 1 - conv in & down # The base -> ctrl connections are "delayed" by 1 subblock, because we want to "wait" to ensure the new information from the last ctrl -> base connection is also considered. # Therefore, the connections iterate over: @@ -767,6 +963,7 @@ def forward( h_base = torch.cat([h_base, skip_b], dim=1) # concat info from base encoder+ctrl encoder h_base = b(h_base, temb, cemb, attention_mask, cross_attention_kwargs) + # 4 - conv out h_base = self.base_model.conv_norm_out(h_base) h_base = self.base_model.conv_act(h_base) h_base = self.base_model.conv_out(h_base) diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 39459e75b989..568eb8a93286 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -23,7 +23,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetXSModel +from ...models import AutoencoderKL, ControlNetXSAddon, UNet2DConditionModel, UNetControlNetXSModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -125,7 +125,8 @@ class StableDiffusionControlNetXSPipeline( A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - model_cpu_offload_seq = "text_encoder->controlnet->vae" + # todo: dont load controlnet to gpu + model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -135,7 +136,8 @@ def __init__( vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - controlnet: ControlNetXSModel, + unet: Union[UNet2DConditionModel, UNetControlNetXSModel], + controlnet: ControlNetXSAddon, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, @@ -143,6 +145,9 @@ def __init__( ): super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetControlNetXSModel(unet, controlnet) + if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" @@ -163,7 +168,7 @@ def __init__( vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor, - ) = controlnet._check_if_vae_compatible(vae) + ) = unet._check_if_vae_compatible(vae) if not vae_compatible: raise ValueError( f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXSAddon model ({cnxs_condition_downsample_factor}) need to be equal. Consider building the ControlNetXSAddon model with different `conditioning_embedding_out_channels`." @@ -173,6 +178,7 @@ def __init__( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, + unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=safety_checker, @@ -185,42 +191,6 @@ def __init__( ) self.register_to_config(requires_safety_checker=requires_safety_checker) - @classmethod - def from_pretrained(cls, base_path, controlnet_addon, time_embedding_mix=1.0, **kwargs): - """ - Instantiates pipeline from a `StableDiffusionPipeline` and a `ControlNetXSAddon`. - - Arguments: - base_path (`str` or `os.PathLike`): - Directory to load underlying `StableDiffusionPipeline` from. - controlnet_addon (`ControlNetXSAddon`): - A `ControlNetXSAddon` model. - kwargs (`Dict[str, Any]`, *optional*): - Additional keyword arguments passed along to the [`~StableDiffusionPipeline.from_pretrained`] method. - """ - - components = StableDiffusionPipeline.from_pretrained(base_path, **kwargs).components - - unet = components["unet"] - - to_ignore = ["image_encoder"] - for item in to_ignore: - if item in components: - print( - f"Loaded base pipeline has component `{item}` which StableDiffusionControlNetXSPipeline can't use. It will be ignored." - ) - - components = {k: v for k, v in components.items() if k not in ["unet"] + to_ignore} - - controlnet = ControlNetXSModel(unet, controlnet_addon, time_embedding_mix) - return StableDiffusionControlNetXSPipeline(controlnet=controlnet, **components) - - def save_pretrained(self, *args, **kwargs): - raise EnvironmentError( - "Save the underlying `StableDiffusionPipeline` and the `ControlNetXSAddon` separately" - " by using `pipe.get_base_pipeline().save_pretrained()` and `pipe.get_controlnet_addon().save_pretrained()`." - ) - def get_base_pipeline(self): """Get underlying `StableDiffusionPipeline` without the `ControlNetXSAddon` model.""" components = {k: v for k, v in self.components.items() if k != "controlnet"} @@ -577,12 +547,12 @@ def check_inputs( # Check `image` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule + self.unet, torch._dynamo.eval_frame.OptimizedModule ) if ( - isinstance(self.controlnet, ControlNetXSModel) + isinstance(self.unet, UNetControlNetXSModel) or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + and isinstance(self.unet._orig_mod, UNetControlNetXSModel) ): self.check_image(image, prompt, prompt_embeds) else: @@ -590,9 +560,9 @@ def check_inputs( # Check `controlnet_conditioning_scale` if ( - isinstance(self.controlnet, ControlNetXSModel) + isinstance(self.unet, UNetControlNetXSModel) or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + and isinstance(self.unet._orig_mod, UNetControlNetXSModel) ): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") @@ -845,7 +815,8 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + # todo umer: what's this for? + controlnet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -903,7 +874,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare image - if isinstance(controlnet, ControlNetXSModel): + if isinstance(controlnet, UNetControlNetXSModel): image = self.prepare_image( image=image, width=width, @@ -923,7 +894,7 @@ def __call__( timesteps = self.scheduler.timesteps # 6. Prepare latent variables - num_channels_latents = self.controlnet.base_model.config.in_channels + num_channels_latents = self.unet.base_model.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -941,7 +912,7 @@ def __call__( # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - is_controlnet_compiled = is_compiled_module(self.controlnet) + is_controlnet_compiled = is_compiled_module(self.unet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -957,7 +928,7 @@ def __call__( do_control = ( i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end ) - noise_pred = self.controlnet( + noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds, diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index 72829e084df7..93844633a9c5 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -30,7 +30,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetXSModel +from ...models import AutoencoderKL, ControlNetXSAddon, UNet2DConditionModel, UNetControlNetXSModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -152,7 +152,8 @@ class StableDiffusionXLControlNetXSPipeline( watermarker is used. """ - model_cpu_offload_seq = "text_encoder->text_encoder_2->controlnet->vae" + # todo: dont load controlnet to gpu + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" _optional_components = [ "tokenizer", "tokenizer_2", @@ -169,7 +170,8 @@ def __init__( text_encoder_2: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, - controlnet: ControlNetXSModel, + unet: Union[UNet2DConditionModel, UNetControlNetXSModel], + controlnet: ControlNetXSAddon, scheduler: KarrasDiffusionSchedulers, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, @@ -177,11 +179,14 @@ def __init__( ): super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetControlNetXSModel(unet, controlnet) + ( vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor, - ) = controlnet._check_if_vae_compatible(vae) + ) = unet._check_if_vae_compatible(vae) if not vae_compatible: raise ValueError( f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXSAddon model ({cnxs_condition_downsample_factor}) need to be equal. Consider building the ControlNetXSAddon model with different `conditioning_embedding_out_channels`." @@ -193,6 +198,7 @@ def __init__( text_encoder_2=text_encoder_2, tokenizer=tokenizer, tokenizer_2=tokenizer_2, + unet=unet, controlnet=controlnet, scheduler=scheduler, feature_extractor=feature_extractor, @@ -238,7 +244,7 @@ def from_pretrained(cls, base_path, controlnet_addon, time_embedding_mix=1.0, ** components = {k: v for k, v in components.items() if k not in ["unet"] + to_ignore} - controlnet = ControlNetXSModel(unet, controlnet_addon, time_embedding_mix) + controlnet = UNetControlNetXSModel(unet, controlnet_addon, time_embedding_mix) return StableDiffusionXLControlNetXSPipeline(controlnet=controlnet, **components) def save_pretrained(self, *args, **kwargs): @@ -487,7 +493,7 @@ def encode_prompt( if self.text_encoder_2 is not None: prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: - prompt_embeds = prompt_embeds.to(dtype=self.controlnet.dtype, device=device) + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method @@ -501,7 +507,7 @@ def encode_prompt( if self.text_encoder_2 is not None: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.controlnet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -624,12 +630,12 @@ def check_inputs( # Check `image` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule + self.unet, torch._dynamo.eval_frame.OptimizedModule ) if ( - isinstance(self.controlnet, ControlNetXSModel) + isinstance(self.unet, UNetControlNetXSModel) or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + and isinstance(self.unet._orig_mod, UNetControlNetXSModel) ): self.check_image(image, prompt, prompt_embeds) else: @@ -637,9 +643,9 @@ def check_inputs( # Check `controlnet_conditioning_scale` if ( - isinstance(self.controlnet, ControlNetXSModel) + isinstance(self.unet, UNetControlNetXSModel) or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + and isinstance(self.unet._orig_mod, UNetControlNetXSModel) ): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") @@ -747,9 +753,9 @@ def _get_add_time_ids( add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.controlnet.base_model.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + self.unet.base_model.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) - expected_add_embed_dim = self.controlnet.base_model.add_embedding.linear_1.in_features + expected_add_embed_dim = self.unet.base_model.add_embedding.linear_1.in_features if expected_add_embed_dim != passed_add_embed_dim: raise ValueError( @@ -985,7 +991,7 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + controlnet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -1050,7 +1056,7 @@ def __call__( ) # 4. Prepare image - if isinstance(controlnet, ControlNetXSModel): + if isinstance(controlnet, UNetControlNetXSModel): image = self.prepare_image( image=image, width=width, @@ -1070,7 +1076,7 @@ def __call__( timesteps = self.scheduler.timesteps # 6. Prepare latent variables - num_channels_latents = self.controlnet.base_model.config.in_channels + num_channels_latents = self.unet.base_model.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -1129,7 +1135,7 @@ def __call__( # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - is_controlnet_compiled = is_compiled_module(self.controlnet) + is_controlnet_compiled = is_compiled_module(self.unet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1147,7 +1153,7 @@ def __call__( do_control = ( i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end ) - noise_pred = self.controlnet( + noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds, diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index 1dbeece8b01e..d2614e023759 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -14,7 +14,6 @@ # limitations under the License. import gc -import tempfile import traceback import unittest @@ -22,17 +21,14 @@ import torch from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer -import diffusers from diffusers import ( AutoencoderKL, ControlNetXSAddon, - ControlNetXSModel, DDIMScheduler, LCMScheduler, StableDiffusionControlNetXSPipeline, UNet2DConditionModel, ) -from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -57,7 +53,6 @@ PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, - to_np, ) @@ -133,14 +128,13 @@ def get_dummy_components(self, time_cond_proj_dim=None): time_cond_proj_dim=time_cond_proj_dim, ) torch.manual_seed(0) - controlnet_addon = ControlNetXSAddon.from_unet( + controlnet = ControlNetXSAddon.from_unet( base_model=unet, size_ratio=0.5, num_attention_heads=2, learn_time_embedding=True, conditioning_embedding_out_channels=(16, 32), ) - controlnet = ControlNetXSModel(base_model=unet, ctrl_addon=controlnet_addon) torch.manual_seed(0) scheduler = DDIMScheduler( beta_start=0.00085, @@ -175,6 +169,7 @@ def get_dummy_components(self, time_cond_proj_dim=None): tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") components = { + "unet": unet, "controlnet": controlnet, "scheduler": scheduler, "vae": vae, @@ -241,137 +236,6 @@ def test_controlnet_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - def test_save_load_local(self, expected_max_difference=5e-4): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - for component in pipe.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(torch_device) - output = pipe(**inputs)[0] - - logger = logging.get_logger("diffusers.pipelines.pipeline_utils") - logger.setLevel(diffusers.logging.INFO) - - with tempfile.TemporaryDirectory() as tmpdir_components: - with tempfile.TemporaryDirectory() as tmpdir_addon: - pipe.get_base_pipeline().save_pretrained(tmpdir_components, safe_serialization=False) - pipe.get_controlnet_addon().save_pretrained(tmpdir_addon, safe_serialization=False) - - addon_loaded = ControlNetXSAddon.from_pretrained(tmpdir_addon) - pipe_loaded = self.pipeline_class.from_pretrained( - base_path=tmpdir_components, controlnet_addon=addon_loaded - ) - - for component in pipe_loaded.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - - pipe_loaded.to(torch_device) - pipe_loaded.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(torch_device) - output_loaded = pipe_loaded(**inputs)[0] - - max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() - self.assertLess(max_diff, expected_max_difference) - - def test_save_load_optional_components(self, expected_max_difference=1e-4): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - for component in pipe.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - # set all optional components to None - for optional_component in pipe._optional_components: - setattr(pipe, optional_component, None) - - generator_device = "cpu" - inputs = self.get_dummy_inputs(generator_device) - output = pipe(**inputs)[0] - - with tempfile.TemporaryDirectory() as tmpdir_components: - with tempfile.TemporaryDirectory() as tmpdir_addon: - pipe.get_base_pipeline().save_pretrained(tmpdir_components, safe_serialization=False) - pipe.get_controlnet_addon().save_pretrained(tmpdir_addon, safe_serialization=False) - - addon_loaded = ControlNetXSAddon.from_pretrained(tmpdir_addon) - pipe_loaded = self.pipeline_class.from_pretrained( - base_path=tmpdir_components, controlnet_addon=addon_loaded - ) - - for component in pipe_loaded.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - pipe_loaded.to(torch_device) - pipe_loaded.set_progress_bar_config(disable=None) - - for optional_component in pipe._optional_components: - self.assertTrue( - getattr(pipe_loaded, optional_component) is None, - f"`{optional_component}` did not stay set to None after loading.", - ) - - inputs = self.get_dummy_inputs(generator_device) - output_loaded = pipe_loaded(**inputs)[0] - - max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() - self.assertLess(max_diff, expected_max_difference) - - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") - def test_save_load_float16(self, expected_max_diff=1e-2): - components = self.get_dummy_components() - for name, module in components.items(): - if hasattr(module, "half"): - components[name] = module.to(torch_device).half() - - pipe = self.pipeline_class(**components) - for component in pipe.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(torch_device) - output = pipe(**inputs)[0] - - with tempfile.TemporaryDirectory() as tmpdir_components: - with tempfile.TemporaryDirectory() as tmpdir_addon: - pipe.save_pretrained( - base_path=tmpdir_components, - addon_path=tmpdir_addon, - base_kwargs={"safe_serialization": False}, - addon_kwargs={"safe_serialization": False}, - ) - - pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, addon_path=tmpdir_addon) - for component in pipe_loaded.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - pipe_loaded.to(torch_device) - pipe_loaded.set_progress_bar_config(disable=None) - - for name, component in pipe_loaded.components.items(): - if hasattr(component, "dtype"): - self.assertTrue( - component.dtype == torch.float16, - f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.", - ) - - inputs = self.get_dummy_inputs(torch_device) - output_loaded = pipe_loaded(**inputs)[0] - max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() - self.assertLess( - max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." - ) - @slow @require_torch_gpu diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index b2bea4c6810d..e08db91d695f 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -25,7 +25,6 @@ from diffusers import ( AutoencoderKL, ControlNetXSAddon, - ControlNetXSModel, EulerDiscreteScheduler, StableDiffusionXLControlNetXSPipeline, UNet2DConditionModel, @@ -88,13 +87,12 @@ def get_dummy_components(self): cross_attention_dim=64, ) torch.manual_seed(0) - controlnet_addon = ControlNetXSAddon.from_unet( + controlnet = ControlNetXSAddon.from_unet( base_model=unet, size_ratio=0.5, learn_time_embedding=True, conditioning_embedding_out_channels=(16, 32), ) - controlnet = ControlNetXSModel(base_model=unet, ctrl_addon=controlnet_addon) torch.manual_seed(0) scheduler = EulerDiscreteScheduler( beta_start=0.00085, @@ -134,6 +132,7 @@ def get_dummy_components(self): tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") components = { + "unet": unet, "controlnet": controlnet, "scheduler": scheduler, "vae": vae, @@ -309,168 +308,6 @@ def test_stable_diffusion_xl_prompt_embeds(self): # make sure that it's equal assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1.1e-4 - # copied from test_controlnetxs.py - def test_save_load_local(self, expected_max_difference=5e-4): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - for component in pipe.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(torch_device) - output = pipe(**inputs)[0] - - logger = logging.get_logger("diffusers.pipelines.pipeline_utils") - logger.setLevel(diffusers.logging.INFO) - - with tempfile.TemporaryDirectory() as tmpdir_components: - with tempfile.TemporaryDirectory() as tmpdir_addon: - pipe.get_base_pipeline().save_pretrained(tmpdir_components, safe_serialization=False) - pipe.get_controlnet_addon().save_pretrained(tmpdir_addon, safe_serialization=False) - - addon_loaded = ControlNetXSAddon.from_pretrained(tmpdir_addon) - pipe_loaded = self.pipeline_class.from_pretrained( - base_path=tmpdir_components, controlnet_addon=addon_loaded - ) - - for component in pipe_loaded.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - - pipe_loaded.to(torch_device) - pipe_loaded.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(torch_device) - output_loaded = pipe_loaded(**inputs)[0] - - max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() - self.assertLess(max_diff, expected_max_difference) - - def test_save_load_optional_components(self, expected_max_difference=1e-4): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - - # set all optional components to None - for optional_component in pipe._optional_components: - setattr(pipe, optional_component, None) - - for component in pipe.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - generator_device = "cpu" - inputs = self.get_dummy_inputs(generator_device) - - tokenizer = components.pop("tokenizer") - tokenizer_2 = components.pop("tokenizer_2") - text_encoder = components.pop("text_encoder") - text_encoder_2 = components.pop("text_encoder_2") - - tokenizers = [tokenizer, tokenizer_2] if tokenizer is not None else [tokenizer_2] - text_encoders = [text_encoder, text_encoder_2] if text_encoder is not None else [text_encoder_2] - prompt = inputs.pop("prompt") - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.encode_prompt(tokenizers, text_encoders, prompt) - inputs["prompt_embeds"] = prompt_embeds - inputs["negative_prompt_embeds"] = negative_prompt_embeds - inputs["pooled_prompt_embeds"] = pooled_prompt_embeds - inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds - - output = pipe(**inputs)[0] - - with tempfile.TemporaryDirectory() as tmpdir_components: - with tempfile.TemporaryDirectory() as tmpdir_addon: - pipe.get_base_pipeline().save_pretrained(tmpdir_components, safe_serialization=False) - pipe.get_controlnet_addon().save_pretrained(tmpdir_addon, safe_serialization=False) - - addon_loaded = ControlNetXSAddon.from_pretrained(tmpdir_addon) - pipe_loaded = self.pipeline_class.from_pretrained( - base_path=tmpdir_components, controlnet_addon=addon_loaded - ) - - for component in pipe_loaded.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - - pipe_loaded.to(torch_device) - pipe_loaded.set_progress_bar_config(disable=None) - - for optional_component in pipe._optional_components: - self.assertTrue( - getattr(pipe_loaded, optional_component) is None, - f"`{optional_component}` did not stay set to None after loading.", - ) - - inputs = self.get_dummy_inputs(generator_device) - - _ = inputs.pop("prompt") - inputs["prompt_embeds"] = prompt_embeds - inputs["negative_prompt_embeds"] = negative_prompt_embeds - inputs["pooled_prompt_embeds"] = pooled_prompt_embeds - inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds - - output_loaded = pipe_loaded(**inputs)[0] - - max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() - self.assertLess(max_diff, expected_max_difference) - - # copied from test_controlnetxs.py - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") - def test_save_load_float16(self, expected_max_diff=1e-2): - components = self.get_dummy_components() - for name, module in components.items(): - if hasattr(module, "half"): - components[name] = module.to(torch_device).half() - - pipe = self.pipeline_class(**components) - for component in pipe.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(torch_device) - output = pipe(**inputs)[0] - - with tempfile.TemporaryDirectory() as tmpdir_components: - with tempfile.TemporaryDirectory() as tmpdir_addon: - pipe.save_pretrained( - base_path=tmpdir_components, - addon_path=tmpdir_addon, - base_kwargs={"safe_serialization": False}, - addon_kwargs={"safe_serialization": False}, - ) - - pipe_loaded = self.pipeline_class.from_pretrained(base_path=tmpdir_components, addon_path=tmpdir_addon) - for component in pipe_loaded.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - pipe_loaded.to(torch_device) - pipe_loaded.set_progress_bar_config(disable=None) - - for name, component in pipe_loaded.components.items(): - if hasattr(component, "dtype"): - self.assertTrue( - component.dtype == torch.float16, - f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.", - ) - - inputs = self.get_dummy_inputs(torch_device) - output_loaded = pipe_loaded(**inputs)[0] - max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() - self.assertLess( - max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." - ) - @slow @require_torch_gpu From 5618c95c8a43519f96a626bb97f18e241a62a673 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Tue, 12 Mar 2024 20:39:02 +0100 Subject: [PATCH 48/75] CheckIn Mar 12 '24 --- src/diffusers/models/controlnet_xs.py | 372 +++++++++--------- .../controlnet_xs/pipeline_controlnet_xs.py | 4 +- .../pipeline_controlnet_xs_sd_xl.py | 44 +-- 3 files changed, 184 insertions(+), 236 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 515ac8581aab..bcc650743667 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -115,6 +115,8 @@ class ControlNetXSAddon(ModelMixin, ConfigMixin): Dimension of input into time embedding. Needs to be same as in the base model. time_embedding_dim (`int`, defaults to 1280): Dimension of output from time embedding. Needs to be same as in the base model. + time_embedding_mix + # todo umer learn_time_embedding (`bool`, defaults to `False`): Whether a time embedding should be learned. If yes, `ControlNetXSModel` will combine the time embeddings of the base model and the addon. If no, `ControlNetXSModel` will use the base model's time embedding. @@ -264,6 +266,7 @@ def __init__( conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), time_embedding_input_dim: Optional[int] = 320, time_embedding_dim: Optional[int] = 1280, + time_embedding_mix: float = 1.0, learn_time_embedding: bool = False, channels_base: Dict[str, List[Tuple[int]]] = { "down - out": [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280], @@ -460,10 +463,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): It's default parameters are compatible with StableDiffusion. Parameters: - base_model (`UNet2DConditionModel`): - The base UNet to control. - ctrl_addon (`ControlNetXSAddon`): - The control addon. + # todo umer time_embedding_mix (`float`, defaults to 1.0): If 0, then only the base model's time embedding is used. If 1, then only the control model's time embedding is used. @@ -474,36 +474,6 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): def __init__( self, # unet configs - conditioning_channels: int = 3, - conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), - time_embedding_input_dim: int = 320, - time_embedding_dim: int = 1280, - time_embedding_mix: float = 1.0, - base_model_channel_sizes: Dict[str, List[Tuple[int]]] = { - "down": [ - (4, 320), - (320, 320), - (320, 320), - (320, 320), - (320, 640), - (640, 640), - (640, 640), - (640, 1280), - (1280, 1280), - ], - "mid": [(1280, 1280)], - "up": [ - (2560, 1280), - (2560, 1280), - (1920, 1280), - (1920, 640), - (1280, 640), - (960, 640), - (960, 320), - (640, 320), - (640, 320), - ], - }, sample_size: Optional[int] = 96, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", @@ -512,32 +482,37 @@ def __init__( "DownBlock2D", ), up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), # for addon: (4, 8, 16, 16) + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), norm_num_groups: Optional[int] = 32, cross_attention_dim: Union[int, Tuple[int]] = 1024, transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, # type Tuple[Tuple] necessary? num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, upcast_attention: bool = True, - # controlnet configs - controlnet_conditioning_channel_order: str = "rgb", - conditioning_learn_time_embedding: bool = False, - channels_base: Dict[str, List[Tuple[int]]] = { - "down - out": [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280], - "mid - out": 1280, - "up - in": [1280, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320], - }, - attention_head_dim: Union[int, Tuple[int]] = 4, - max_norm_num_groups: int = 32, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + time_embedding_dim: Optional[int] = None, + # additional controlnet configs + time_embedding_mix: float = 1.0, + ctrl_conditioning_channels: int = 3, + ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + ctrl_time_embedding_input_dim: int = 320, + ctrl_conditioning_channel_order: str = "rgb", + ctrl_learn_time_embedding: bool = False, + ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16), + ctrl_attention_head_dim: Union[int, Tuple[int]] = 4, + ctrl_max_norm_num_groups: int = 32, ): super().__init__() if time_embedding_mix < 0 or time_embedding_mix > 1: raise ValueError("`time_embedding_mix` needs to be between 0 and 1.") - if time_embedding_mix < 1 and not conditioning_learn_time_embedding: + if time_embedding_mix < 1 and not ctrl_learn_time_embedding: raise ValueError( "To use `time_embedding_mix` < 1, initialize `ctrl_addon` with `learn_time_embedding = True`" ) + time_embedding_dim = time_embedding_dim or block_out_channels[0] * 4 + # Create UNet and decompose it into subblocks, which we then save base_model = UNet2DConditionModel( sample_size=sample_size, @@ -551,17 +526,57 @@ def __init__( use_linear_projection=True, upcast_attention=upcast_attention, time_embedding_dim=time_embedding_dim, + class_embed_type=class_embed_type, + addition_embed_type=addition_embed_type, ) - self.base_down_subblocks = nn.ModuleList() - self.base_up_subblocks = nn.ModuleList() + self.in_channels = 4 + + self.base_time_proj = base_model.time_proj + self.base_time_embedding = base_model.time_embedding + self.base_class_embedding = base_model.class_embedding + self.base_add_time_proj = base_model.add_time_proj if hasattr(base_model, 'add_time_proj') else None + self.base_add_embedding = base_model.add_embedding if hasattr(base_model, 'add_embedding') else None + + self.base_conv_in = base_model.conv_in + self.base_mid_block = base_model.mid_block + self.base_conv_norm_out = base_model.conv_norm_out + self.base_conv_act = base_model.conv_act + self.base_conv_out = base_model.conv_out + + self.base_down_subblocks, self.base_up_subblocks = UNetControlNetXSModel._unet_to_subblocks(base_model) + + self.control_addon = ControlNetXSAddon( + conditioning_channels=ctrl_conditioning_channels, + conditioning_channel_order=ctrl_conditioning_channel_order, + conditioning_embedding_out_channels=ctrl_conditioning_embedding_out_channels, + time_embedding_input_dim=ctrl_time_embedding_input_dim, + time_embedding_dim=time_embedding_dim, + time_embedding_mix=time_embedding_mix, + learn_time_embedding=ctrl_learn_time_embedding, + channels_base=ControlNetXSAddon.gather_base_subblock_sizes(block_out_channels), + attention_head_dim=ctrl_attention_head_dim, + block_out_channels=ctrl_block_out_channels, + cross_attention_dim=cross_attention_dim, + down_block_types=down_block_types, + sample_size=sample_size, + transformer_layers_per_block=transformer_layers_per_block, + upcast_attention=upcast_attention, + max_norm_num_groups=ctrl_max_norm_num_groups, + ) + + @classmethod + def _unet_to_subblocks(cls, unet: UNet2DConditionModel): + """todo umer""" + down_subblocks = nn.ModuleList() + up_subblocks = nn.ModuleList() - for block in base_model.down_blocks: + for block in unet.down_blocks: # Each ResNet / Attention pair is a subblock resnets = block.resnets attentions = block.attentions if hasattr(block, "attentions") else [None] * len(resnets) for r, a in zip(resnets, attentions): - self.base_down_subblocks.append(CrossAttnSubBlock2D.from_modules(r, a)) + down_subblocks.append(CrossAttnSubBlock2D.from_modules(r, a)) # Each Downsampler is a subblock if block.downsamplers is not None: if len(block.downsamplers) != 1: @@ -569,9 +584,9 @@ def __init__( "ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL." "Therefore each down block of the base model should have only 1 downsampler (if any)." ) - self.base_down_subblocks.append(DownSubBlock2D.from_modules(block.downsamplers[0])) + down_subblocks.append(DownSubBlock2D.from_modules(block.downsamplers[0])) - for block in base_model.up_blocks: + for block in unet.up_blocks: # Each ResNet / Attention / Upsampler triple is a subblock if block.upsamplers is not None: if len(block.upsamplers) != 1: @@ -587,29 +602,10 @@ def __init__( attentions = block.attentions if hasattr(block, "attentions") else [None] * len(resnets) upsamplers = [None] * (len(resnets) - 1) + [upsampler] for r, a, u in zip(resnets, attentions, upsamplers): - self.base_up_subblocks.append(CrossAttnUpSubBlock2D.from_modules(r, a, u)) + up_subblocks.append(CrossAttnUpSubBlock2D.from_modules(r, a, u)) - self.control_addon = ControlNetXSAddon( - conditioning_channels=conditioning_channels, - conditioning_channel_order=controlnet_conditioning_channel_order, - conditioning_embedding_out_channels=conditioning_embedding_out_channels, - time_embedding_input_dim=time_embedding_input_dim, - time_embedding_dim=time_embedding_dim, - learn_time_embedding=conditioning_learn_time_embedding, - channels_base=channels_base, - attention_head_dim=attention_head_dim, - block_out_channels=block_out_channels, - cross_attention_dim=cross_attention_dim, - down_block_types=down_block_types, - sample_size=sample_size, - transformer_layers_per_block=transformer_layers_per_block, - upcast_attention=upcast_attention, - max_norm_num_groups=max_norm_num_groups - ) - - self.time_embedding_mix = time_embedding_mix + return down_subblocks, up_subblocks - # todo umer @classmethod def from_unet2d( cls, @@ -617,110 +613,98 @@ def from_unet2d( controlnet: ControlNetXSAddon, load_weights: bool = True, ): - # analogous to diffusers.models.unets.unet_motion_model.UNetMotionModel.from_unet2d - config = unet.config - config["_class_name"] = cls.__name__ - - down_blocks = [] - for down_blocks_type in config["down_block_types"]: - if "CrossAttn" in down_blocks_type: - down_blocks.append("CrossAttnDownBlockMotion") - else: - down_blocks.append("DownBlockMotion") - config["down_block_types"] = down_blocks - - up_blocks = [] - for down_blocks_type in config["up_block_types"]: - if "CrossAttn" in down_blocks_type: - up_blocks.append("CrossAttnUpBlockMotion") - else: - up_blocks.append("UpBlockMotion") - - config["up_block_types"] = up_blocks - - if has_motion_adapter: - config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"] - config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"] - config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"] + # todo umer: assert unet is sd/sdxl? - # For PIA UNets we need to set the number input channels to 9 - if motion_adapter.config["conv_in_channels"]: - config["in_channels"] = motion_adapter.config["conv_in_channels"] + # Create config for UNetControlNetXSModel object + config = {} + config["_class_name"] = cls.__name__ - # Need this for backwards compatibility with UNet2DConditionModel checkpoints - if not config.get("num_attention_heads"): - config["num_attention_heads"] = config["attention_head_dim"] + params_for_unet = [ + "time_embedding_dim", + "sample_size", + "down_block_types", + "up_block_types", + "block_out_channels", + "norm_num_groups", + "cross_attention_dim", + "transformer_layers_per_block", + "upcast_attention", + "class_embed_type", + "addition_embed_type", + ] + config.update({k:v for k,v in unet.config.items() if k in params_for_unet}) + # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. + config["num_attention_heads"] = unet.config.attention_head_dim + + params_for_controlnet = [ + "conditioning_channels", + "conditioning_embedding_out_channels", + "conditioning_channel_order", + "time_embedding_input_dim", + "learn_time_embedding", + "block_out_channels", + "attention_head_dim", + "max_norm_num_groups" + ] + config.update({"ctrl_"+k:v for k,v in controlnet.config.items() if k in params_for_controlnet}) model = cls.from_config(config) if not load_weights: return model - # Logic for loading PIA UNets which allow the first 4 channels to be any UNet2DConditionModel conv_in weight - # while the last 5 channels must be PIA conv_in weights. - if has_motion_adapter and motion_adapter.config["conv_in_channels"]: - model.conv_in = motion_adapter.conv_in - updated_conv_in_weight = torch.cat( - [unet.conv_in.weight, motion_adapter.conv_in.weight[:, 4:, :, :]], dim=1 - ) - model.conv_in.load_state_dict({"weight": updated_conv_in_weight, "bias": unet.conv_in.bias}) - else: - model.conv_in.load_state_dict(unet.conv_in.state_dict()) - - model.time_proj.load_state_dict(unet.time_proj.state_dict()) - model.time_embedding.load_state_dict(unet.time_embedding.state_dict()) - - for i, down_block in enumerate(unet.down_blocks): - model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict()) - if hasattr(model.down_blocks[i], "attentions"): - model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict()) - if model.down_blocks[i].downsamplers: - model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict()) - - for i, up_block in enumerate(unet.up_blocks): - model.up_blocks[i].resnets.load_state_dict(up_block.resnets.state_dict()) - if hasattr(model.up_blocks[i], "attentions"): - model.up_blocks[i].attentions.load_state_dict(up_block.attentions.state_dict()) - if model.up_blocks[i].upsamplers: - model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict()) - - model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict()) - model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict()) - - if unet.conv_norm_out is not None: - model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict()) - if unet.conv_act is not None: - model.conv_act.load_state_dict(unet.conv_act.state_dict()) - model.conv_out.load_state_dict(unet.conv_out.state_dict()) - - if has_motion_adapter: - model.load_motion_modules(motion_adapter) - - # ensure that the Motion UNet is the same dtype as the UNet2DConditionModel + # Load params + modules_from_unet = [ + "time_proj", + "time_embedding", + "conv_in", + "mid_block", + "conv_norm_out", + "conv_act", + "conv_out" + ] + for m in modules_from_unet: + getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict()) + + optional_modules_from_unet = ["class_embedding"] + for m in optional_modules_from_unet: + module = getattr(model, "base_" + m) + if module is not None: + module.load_state_dict(getattr(unet, m).state_dict()) + + sdxl_specific_modules_from_unet = [ + "add_time_proj", + "add_embedding", + ] + if hasattr(unet, sdxl_specific_modules_from_unet[0]): + # if the UNet has any of the sdxl-specific components, it is an sdxl and has all of them + for m in sdxl_specific_modules_from_unet: + getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict()) + + model.base_down_subblocks, model.base_up_subblocks = UNetControlNetXSModel._unet_to_subblocks(unet) + + model.control_addon.load_state_dict(controlnet.state_dict()) + + # ensure that the UNetControlNetXSModel is the same dtype as the UNet2DConditionModel model.to(unet.dtype) return model + def freeze_unet2d_params(self) -> None: + """Freeze the weights of just the UNet2DConditionModel, and leave the ControlNetXSAddon + unfrozen for fine tuning. + """ + # Freeze everything + for param in self.parameters(): + param.requires_grad = False - # todo umer - def load_controlnet_addon(self, controlnet: ControlNetXSAddon) -> None: - pass - - # todo umer - def save_controlnet_addon( - self, - save_directory: str, - is_main_process: bool = True, - safe_serialization: bool = True, - variant: Optional[str] = None, - push_to_hub: bool = False, - **kwargs, - ) -> None: - pass + # Unfreeze ControlNetXSAddon + for param in self.control_addon.parameters(): + param.requires_grad = True @torch.no_grad() def _check_if_vae_compatible(self, vae: AutoencoderKL): - condition_downscale_factor = 2 ** (len(self.ctrl_addon.config.conditioning_embedding_out_channels) - 1) + condition_downscale_factor = 2 ** (len(self.control_addon.config.conditioning_embedding_out_channels) - 1) vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1) compatible = condition_downscale_factor == vae_downscale_factor return compatible, condition_downscale_factor, vae_downscale_factor @@ -780,7 +764,7 @@ def forward( """ # check channel order - if self.ctrl_addon.config.conditioning_channel_order == "bgr": + if self.control_addon.config.conditioning_channel_order == "bgr": controlnet_cond = torch.flip(controlnet_cond, dims=[1]) # prepare attention_mask @@ -805,38 +789,38 @@ def forward( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) - t_emb = self.base_model.time_proj(timesteps) + t_emb = self.base_time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) - if self.ctrl_addon.config.learn_time_embedding: - ctrl_temb = self.ctrl_addon.time_embedding(t_emb, timestep_cond) - base_temb = self.base_model.time_embedding(t_emb, timestep_cond) - interpolation_param = self.time_embedding_mix**0.3 + if self.config.ctrl_learn_time_embedding: + ctrl_temb = self.control_addon.time_embedding(t_emb, timestep_cond) + base_temb = self.base_time_embedding(t_emb, timestep_cond) + interpolation_param = self.control_addon.time_embedding_mix**0.3 temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) else: - temb = self.base_model.time_embedding(t_emb) + temb = self.base_time_embedding(t_emb) # added time & text embeddings aug_emb = None - if self.base_model.class_embedding is not None: + if self.base_class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") - if self.base_model.config.class_embed_type == "timestep": + if self.config.class_embed_type == "timestep": class_labels = self.base_time_proj(class_labels) - class_emb = self.base_model.class_embedding(class_labels).to(dtype=self.dtype) + class_emb = self.base_class_embedding(class_labels).to(dtype=self.dtype) temb = temb + class_emb - if self.base_model.config.addition_embed_type is None: + if self.config.addition_embed_type is None: pass - elif self.base_model.config.addition_embed_type == "text_time": + elif self.config.addition_embed_type == "text_time": # SDXL - style if "text_embeds" not in added_cond_kwargs: raise ValueError( @@ -848,14 +832,14 @@ def forward( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") - time_embeds = self.base_model.add_time_proj(time_ids.flatten()) + time_embeds = self.base_add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(temb.dtype) - aug_emb = self.base_model.add_embedding(add_embeds) + aug_emb = self.base_add_embedding(add_embeds) else: raise ValueError( - f"ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL, so addition_embed_type = {self.base_model.config.addition_embed_type} is currently not supported." + f"ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL, so addition_embed_type = {self.config.addition_embed_type} is currently not supported." ) temb = temb + aug_emb if aug_emb is not None else temb @@ -864,7 +848,7 @@ def forward( cemb = encoder_hidden_states # Preparation - guided_hint = self.ctrl_addon.controlnet_cond_embedding(controlnet_cond) + guided_hint = self.control_addon.controlnet_cond_embedding(controlnet_cond) h_ctrl = h_base = sample hs_base, hs_ctrl = [], [] @@ -872,18 +856,18 @@ def forward( # Cross Control # Let's first define variables to shorten notation base_down_subblocks = self.base_down_subblocks - ctrl_down_subblocks = self.ctrl_addon.down_subblocks + ctrl_down_subblocks = self.control_addon.down_subblocks - down_zero_convs_b2c = self.ctrl_addon.down_zero_convs_b2c - down_zero_convs_c2b = self.ctrl_addon.down_zero_convs_c2b - mid_zero_convs_c2b = self.ctrl_addon.mid_zero_convs_c2b - up_zero_convs_c2b = self.ctrl_addon.up_zero_convs_c2b + down_zero_convs_b2c = self.control_addon.down_zero_convs_b2c + down_zero_convs_c2b = self.control_addon.down_zero_convs_c2b + mid_zero_convs_c2b = self.control_addon.mid_zero_convs_c2b + up_zero_convs_c2b = self.control_addon.up_zero_convs_c2b if not do_control: # Run the base model without control # 1 - conv in & down - h_base = self.base_model.conv_in(h_base) + h_base = self.base_conv_in(h_base) hs_base.append(h_base) for b in base_down_subblocks: @@ -895,16 +879,16 @@ def forward( hs_base.append(h_base) # 2 - mid - h_base = self.base_model.mid_block(h_base, temb, cemb, attention_mask, cross_attention_kwargs) + h_base = self.base_mid_block(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # 3 - up for b, skip_b in zip(self.base_up_subblocks, reversed(hs_base)): h_base = torch.cat([h_base, skip_b], dim=1) # concat info from base encoder h_base = b(h_base, temb, cemb, attention_mask, cross_attention_kwargs) - h_base = self.base_model.conv_norm_out(h_base) - h_base = self.base_model.conv_act(h_base) - h_base = self.base_model.conv_out(h_base) + h_base = self.base_conv_norm_out(h_base) + h_base = self.base_conv_act(h_base) + h_base = self.base_conv_out(h_base) if not return_dict: return h_base @@ -917,8 +901,8 @@ def forward( # ctrl -> base: conv_in | subblock 1 | ... | subblock n # base -> ctrl: | subblock 1 | ... | subblock n | mid block - h_base = self.base_model.conv_in(h_base) - h_ctrl = self.ctrl_addon.conv_in(h_ctrl) + h_base = self.base_conv_in(h_base) + h_ctrl = self.control_addon.conv_in(h_ctrl) if guided_hint is not None: h_ctrl += guided_hint h_base = h_base + down_zero_convs_c2b[0](h_ctrl) * conditioning_scale # add ctrl -> base @@ -947,10 +931,10 @@ def forward( h_ctrl = torch.cat([h_ctrl, down_zero_convs_b2c[-1](h_base)], dim=1) # concat base -> ctrl # 2 - mid - h_base = self.base_model.mid_block( + h_base = self.base_mid_block( h_base, temb, cemb, attention_mask, cross_attention_kwargs ) # apply base subblock - h_ctrl = self.ctrl_addon.mid_block( + h_ctrl = self.control_addon.mid_block( h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs ) # apply ctrl subblock h_base = h_base + mid_zero_convs_c2b(h_ctrl) * conditioning_scale # add ctrl -> base @@ -964,9 +948,9 @@ def forward( h_base = b(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # 4 - conv out - h_base = self.base_model.conv_norm_out(h_base) - h_base = self.base_model.conv_act(h_base) - h_base = self.base_model.conv_out(h_base) + h_base = self.base_conv_norm_out(h_base) + h_base = self.base_conv_act(h_base) + h_base = self.base_conv_out(h_base) if not return_dict: return h_base diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 568eb8a93286..5aabc578392c 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -146,7 +146,7 @@ def __init__( super().__init__() if isinstance(unet, UNet2DConditionModel): - unet = UNetControlNetXSModel(unet, controlnet) + unet = UNetControlNetXSModel.from_unet2d(unet, controlnet) if safety_checker is None and requires_safety_checker: logger.warning( @@ -894,7 +894,7 @@ def __call__( timesteps = self.scheduler.timesteps # 6. Prepare latent variables - num_channels_latents = self.unet.base_model.config.in_channels + num_channels_latents = self.unet.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index 93844633a9c5..b0b87c718acf 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -180,7 +180,7 @@ def __init__( super().__init__() if isinstance(unet, UNet2DConditionModel): - unet = UNetControlNetXSModel(unet, controlnet) + unet = UNetControlNetXSModel.from_unet2d(unet, controlnet) ( vae_compatible, @@ -217,42 +217,6 @@ def __init__( self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - @classmethod - def from_pretrained(cls, base_path, controlnet_addon, time_embedding_mix=1.0, **kwargs): - """ - Instantiates pipeline from a `StableDiffusionXLPipeline` and a `ControlNetXSAddon`. - - Arguments: - base_path (`str` or `os.PathLike`): - Directory to load underlying `StableDiffusionXLPipeline` from. - controlnet_addon (`ControlNetXSAddon`): - A `ControlNetXSAddon` model. - kwargs (`Dict[str, Any]`, *optional*): - Additional keyword arguments passed along to the [`~StableDiffusionXLPipeline.from_pretrained`] method. - """ - - components = StableDiffusionXLPipeline.from_pretrained(base_path, **kwargs).components - - unet = components["unet"] - - to_ignore = ["image_encoder"] - for item in to_ignore: - if item in components: - print( - f"Loaded base pipeline has component `{item}` which StableDiffusionControlNetXSPipeline can't use. It will be ignored." - ) - - components = {k: v for k, v in components.items() if k not in ["unet"] + to_ignore} - - controlnet = UNetControlNetXSModel(unet, controlnet_addon, time_embedding_mix) - return StableDiffusionXLControlNetXSPipeline(controlnet=controlnet, **components) - - def save_pretrained(self, *args, **kwargs): - raise EnvironmentError( - "Save the underlying `StableDiffusionXLPipeline` and the `ControlNetXSAddon` separately" - " by using `pipe.get_base_pipeline().save_pretrained()` and `pipe.get_controlnet_addon().save_pretrained()`." - ) - def get_base_pipeline(self): """Get underlying `StableDiffusionXLPipeline` without the `ControlNetXSAddon` model.""" components = {k: v for k, v in self.components.items() if k != "controlnet"} @@ -753,9 +717,9 @@ def _get_add_time_ids( add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.base_model.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) - expected_add_embed_dim = self.unet.base_model.add_embedding.linear_1.in_features + expected_add_embed_dim = self.base_add_embedding.linear_1.in_features if expected_add_embed_dim != passed_add_embed_dim: raise ValueError( @@ -1076,7 +1040,7 @@ def __call__( timesteps = self.scheduler.timesteps # 6. Prepare latent variables - num_channels_latents = self.unet.base_model.config.in_channels + num_channels_latents = self.unet.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, From da96576ed4e8bb8cee12e3b5c3460e487d86fc23 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 13 Mar 2024 18:41:27 +0100 Subject: [PATCH 49/75] Fixed tests for SD --- Pipfile | 11 ++ src/diffusers/models/controlnet_xs.py | 51 ++--- .../unets/test_models_unet_controlnetxs.py | 182 ++++++++++++++++++ .../controlnet_xs/test_controlnetxs.py | 57 ++++++ .../controlnet_xs/test_controlnetxs_sdxl.py | 63 +++++- 5 files changed, 337 insertions(+), 27 deletions(-) create mode 100644 Pipfile create mode 100644 tests/models/unets/test_models_unet_controlnetxs.py diff --git a/Pipfile b/Pipfile new file mode 100644 index 000000000000..0757494bb360 --- /dev/null +++ b/Pipfile @@ -0,0 +1,11 @@ +[[source]] +url = "https://pypi.org/simple" +verify_ssl = true +name = "pypi" + +[packages] + +[dev-packages] + +[requires] +python_version = "3.11" diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index bcc650743667..608e98dd4847 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -227,7 +227,7 @@ def from_unet( relative_size = size_ratio is not None if not (fixed_size ^ relative_size): raise ValueError( - "Pass exactly one of `block_out_channels` (for absolute sizing) or `control_model_ratio` (for relative sizing)." + "Pass exactly one of `block_out_channels` (for absolute sizing) or `size_ratio` (for relative sizing)." ) channels_base = ControlNetXSAddon.gather_base_subblock_sizes(base_model.config.block_out_channels) @@ -351,7 +351,7 @@ def __init__( use_crossattention = down_block_type == "CrossAttnDownBlock2D" self.down_subblocks.append( - CrossAttnSubBlock2D( + CrossAttnDownSubBlock2D( has_crossattn=use_crossattention, in_channels=input_channel + channels_base["down - out"][subblock_counter], out_channels=output_channel, @@ -365,7 +365,7 @@ def __init__( ) subblock_counter += 1 self.down_subblocks.append( - CrossAttnSubBlock2D( + CrossAttnDownSubBlock2D( has_crossattn=use_crossattention, in_channels=output_channel + channels_base["down - out"][subblock_counter], out_channels=output_channel, @@ -470,6 +470,8 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): Otherwise, both are combined. """ + _supports_gradient_checkpointing = True + @register_to_config def __init__( self, @@ -485,12 +487,13 @@ def __init__( block_out_channels: Tuple[int] = (320, 640, 1280, 1280), norm_num_groups: Optional[int] = 32, cross_attention_dim: Union[int, Tuple[int]] = 1024, - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, # type Tuple[Tuple] necessary? + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, # type Tuple[Tuple] necessary? num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, upcast_attention: bool = True, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, time_embedding_dim: Optional[int] = None, + time_cond_proj_dim: Optional[int] = None, # additional controlnet configs time_embedding_mix: float = 1.0, ctrl_conditioning_channels: int = 3, @@ -528,6 +531,7 @@ def __init__( time_embedding_dim=time_embedding_dim, class_embed_type=class_embed_type, addition_embed_type=addition_embed_type, + time_cond_proj_dim=time_cond_proj_dim, ) self.in_channels = 4 @@ -535,8 +539,8 @@ def __init__( self.base_time_proj = base_model.time_proj self.base_time_embedding = base_model.time_embedding self.base_class_embedding = base_model.class_embedding - self.base_add_time_proj = base_model.add_time_proj if hasattr(base_model, 'add_time_proj') else None - self.base_add_embedding = base_model.add_embedding if hasattr(base_model, 'add_embedding') else None + self.base_add_time_proj = base_model.add_time_proj if hasattr(base_model, "add_time_proj") else None + self.base_add_embedding = base_model.add_embedding if hasattr(base_model, "add_embedding") else None self.base_conv_in = base_model.conv_in self.base_mid_block = base_model.mid_block @@ -576,7 +580,7 @@ def _unet_to_subblocks(cls, unet: UNet2DConditionModel): resnets = block.resnets attentions = block.attentions if hasattr(block, "attentions") else [None] * len(resnets) for r, a in zip(resnets, attentions): - down_subblocks.append(CrossAttnSubBlock2D.from_modules(r, a)) + down_subblocks.append(CrossAttnDownSubBlock2D.from_modules(r, a)) # Each Downsampler is a subblock if block.downsamplers is not None: if len(block.downsamplers) != 1: @@ -631,8 +635,9 @@ def from_unet2d( "upcast_attention", "class_embed_type", "addition_embed_type", + "time_cond_proj_dim", ] - config.update({k:v for k,v in unet.config.items() if k in params_for_unet}) + config.update({k: v for k, v in unet.config.items() if k in params_for_unet}) # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. config["num_attention_heads"] = unet.config.attention_head_dim @@ -644,9 +649,9 @@ def from_unet2d( "learn_time_embedding", "block_out_channels", "attention_head_dim", - "max_norm_num_groups" + "max_norm_num_groups", ] - config.update({"ctrl_"+k:v for k,v in controlnet.config.items() if k in params_for_controlnet}) + config.update({"ctrl_" + k: v for k, v in controlnet.config.items() if k in params_for_controlnet}) model = cls.from_config(config) @@ -661,7 +666,7 @@ def from_unet2d( "mid_block", "conv_norm_out", "conv_act", - "conv_out" + "conv_out", ] for m in modules_from_unet: getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict()) @@ -709,13 +714,17 @@ def _check_if_vae_compatible(self, vae: AutoencoderKL): compatible = condition_downscale_factor == vae_downscale_factor return compatible, condition_downscale_factor, vae_downscale_factor + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, - controlnet_cond: torch.Tensor, - conditioning_scale: float = 1.0, + controlnet_cond: Optional[torch.Tensor] = None, + conditioning_scale: Optional[float] = 1.0, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -799,7 +808,7 @@ def forward( if self.config.ctrl_learn_time_embedding: ctrl_temb = self.control_addon.time_embedding(t_emb, timestep_cond) base_temb = self.base_time_embedding(t_emb, timestep_cond) - interpolation_param = self.control_addon.time_embedding_mix**0.3 + interpolation_param = self.control_addon.config.time_embedding_mix**0.3 temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) else: @@ -848,8 +857,6 @@ def forward( cemb = encoder_hidden_states # Preparation - guided_hint = self.control_addon.controlnet_cond_embedding(controlnet_cond) - h_ctrl = h_base = sample hs_base, hs_ctrl = [], [] @@ -871,7 +878,7 @@ def forward( hs_base.append(h_base) for b in base_down_subblocks: - if isinstance(b, CrossAttnSubBlock2D): + if isinstance(b, CrossAttnDownSubBlock2D): additional_params = [temb, cemb, attention_mask, cross_attention_kwargs] else: additional_params = [] @@ -895,6 +902,8 @@ def forward( return ControlNetXSOutput(sample=h_base) + guided_hint = self.control_addon.controlnet_cond_embedding(controlnet_cond) + # 1 - conv in & down # The base -> ctrl connections are "delayed" by 1 subblock, because we want to "wait" to ensure the new information from the last ctrl -> base connection is also considered. # Therefore, the connections iterate over: @@ -916,7 +925,7 @@ def forward( down_zero_convs_b2c[:-1], down_zero_convs_c2b[1:], ): - if isinstance(b, CrossAttnSubBlock2D): + if isinstance(b, CrossAttnDownSubBlock2D): additional_params = [temb, cemb, attention_mask, cross_attention_kwargs] else: additional_params = [] @@ -931,9 +940,7 @@ def forward( h_ctrl = torch.cat([h_ctrl, down_zero_convs_b2c[-1](h_base)], dim=1) # concat base -> ctrl # 2 - mid - h_base = self.base_mid_block( - h_base, temb, cemb, attention_mask, cross_attention_kwargs - ) # apply base subblock + h_base = self.base_mid_block(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # apply base subblock h_ctrl = self.control_addon.mid_block( h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs ) # apply ctrl subblock @@ -975,7 +982,7 @@ def find_largest_factor(number, max_factor): factor -= 1 -class CrossAttnSubBlock2D(nn.Module): +class CrossAttnDownSubBlock2D(nn.Module): def __init__( self, is_empty: bool = False, diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py new file mode 100644 index 000000000000..1c6c4382bf67 --- /dev/null +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -0,0 +1,182 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import os +import re +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers import ControlNetXSAddon, UNet2DConditionModel, UNetControlNetXSModel +from diffusers.utils import logging +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin + + +logger = logging.get_logger(__name__) + +enable_full_determinism() + + +class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): + model_class = UNetControlNetXSModel + main_input_name = "sample" + + def get_dummy_components(self, seed=0): + torch.manual_seed(seed) + unet = UNet2DConditionModel( + block_out_channels=(4, 8), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + norm_num_groups=1, + use_linear_projection=True, + ) + controlnet = ControlNetXSAddon.from_unet(unet, size_ratio=1) + return unet, controlnet + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device) + + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + + @property + def input_shape(self): + return (4, 8, 32, 32) + + def test_from_unet2d(self): + torch.manual_seed(0) + unet2d, controlnet = self.get_dummy_components() + + model = UNetControlNetXSModel.from_unet2d(unet2d, controlnet) + model_state_dict = model.state_dict() + + def is_decomposed(module_name): + return "down_block" in module_name or "up_block" in module_name + + def block_to_subblock_name(param_name): + """ + Map name of a param from 'block notation' as in UNet to 'subblock notation' as in UNetControlNetXS + e.g. 'down_blocks.1.attentions.0.proj_in.weight' -> 'base_down_subblocks.3.attention.proj_in.weight' + """ + param_name = param_name.replace("down_blocks", "base_down_subblocks") + param_name = param_name.replace("up_blocks", "base_up_subblocks") + + numbers = re.findall(r"\d+", param_name) + block_idx, module_idx = int(numbers[0]), int(numbers[1]) + + layers_per_block = 2 + subblocks_per_block = layers_per_block + 1 # include down/upsampler + + if "downsampler" in param_name or "upsampler" in param_name: + subblock_idx = block_idx * subblocks_per_block + layers_per_block + else: + subblock_idx = block_idx * subblocks_per_block + module_idx + + param_name = re.sub(r"\d", str(subblock_idx), param_name, count=1) + param_name = re.sub(r"resnets\.\d+", "resnet", param_name) # eg resnets.1 -> resnet + param_name = re.sub(r"attentions\.\d+", "attention", param_name) # eg attentions.1 -> attention + param_name = re.sub(r"downsamplers\.\d+", "downsampler", param_name) # eg attentions.1 -> attention + param_name = re.sub(r"upsamplers\.\d+", "upsampler", param_name) # eg attentions.1 -> attention + + return param_name + + for param_name, param_value in unet2d.named_parameters(): + if is_decomposed(param_name): + # check unet modules that were decomposed + self.assertTrue(torch.equal(model_state_dict[block_to_subblock_name(param_name)], param_value)) + else: + # check unet modules that were copied as is + self.assertTrue(torch.equal(model_state_dict["base_" + param_name], param_value)) + + # check controlnet + for param_name, param_value in controlnet.named_parameters(): + self.assertTrue(torch.equal(model_state_dict["control_addon." + param_name], param_value)) + + def test_freeze_unet2d(self): + model = UNetControlNetXSModel.from_unet2d(*self.get_dummy_components()) + model.freeze_unet2d_params() + + for param_name, param_value in model.named_parameters(): + if "control_addon" not in param_name: + self.assertFalse(param_value.requires_grad) + else: + self.assertTrue(param_value.requires_grad) + + def test_no_control(self): + unet2d, controlnet = self.get_dummy_components() + + model = UNetControlNetXSModel.from_unet2d(unet2d, controlnet) + + unet2d = unet2d.to(torch_device) + model = model.to(torch_device) + + input_ = self.dummy_input + with torch.no_grad(): + unet_output = unet2d(**input_).sample.cpu() + unet_controlnet_output = model(**input_, do_control=False).sample.cpu() + + assert np.abs(unet_output.flatten() - unet_controlnet_output.flatten()).max() < 1e-5 + + def test_gradient_checkpointing_is_applied(self): + model_class_copy = copy.copy(UNetControlNetXSModel) + + modules_with_gc_enabled = {} + + # now monkey patch the following function: + # def _set_gradient_checkpointing(self, module, value=False): + # if hasattr(module, "gradient_checkpointing"): + # module.gradient_checkpointing = value + + def _set_gradient_checkpointing_new(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + modules_with_gc_enabled[module.__class__.__name__] = True + + model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new + + model = model_class_copy.from_unet2d(*self.get_dummy_components()) + model.enable_gradient_checkpointing() + + EXPECTED_SET = { + "Transformer2DModel", + "UNetMidBlock2DCrossAttn", + "CrossAttnDownSubBlock2D", + "DownSubBlock2D", + "CrossAttnUpSubBlock2D" + } + + assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET + assert all(modules_with_gc_enabled.values()), "All modules should be enabled" diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index d2614e023759..eeebe544d8d4 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -22,7 +22,10 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import ( + AsymmetricAutoencoderKL, AutoencoderKL, + AutoencoderTiny, + ConsistencyDecoderVAE, ControlNetXSAddon, DDIMScheduler, LCMScheduler, @@ -43,6 +46,12 @@ ) from diffusers.utils.torch_utils import randn_tensor +from ...models.autoencoders.test_models_vae import ( + get_asym_autoencoder_kl_config, + get_autoencoder_kl_config, + get_autoencoder_tiny_config, + get_consistency_vae_config, +) from ..pipeline_params import ( IMAGE_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_BATCH_PARAMS, @@ -126,6 +135,7 @@ def get_dummy_components(self, time_cond_proj_dim=None): cross_attention_dim=32, norm_num_groups=1, time_cond_proj_dim=time_cond_proj_dim, + use_linear_projection=True, ) torch.manual_seed(0) controlnet = ControlNetXSAddon.from_unet( @@ -236,6 +246,53 @@ def test_controlnet_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_to_dtype(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the dtype from pipe.components + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) + + pipe.to(dtype=torch.float16) + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) + + def test_multi_vae(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + block_out_channels = pipe.vae.config.block_out_channels + norm_num_groups = pipe.vae.config.norm_num_groups + + vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny] + configs = [ + get_autoencoder_kl_config(block_out_channels, norm_num_groups), + get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups), + get_consistency_vae_config(block_out_channels, norm_num_groups), + get_autoencoder_tiny_config(block_out_channels), + ] + + out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0] + + for vae_cls, config in zip(vae_classes, configs): + vae = vae_cls(**config) + vae = vae.to(torch_device) + components["vae"] = vae + vae_pipe = self.pipeline_class(**components) + + # pipeline creates a new UNetControlNetXSModel under the hood, which aren't on device. + # So we need to move the new pipe to device. + vae_pipe.to(torch_device) + vae_pipe.set_progress_bar_config(disable=None) + + out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0] + + assert out_vae_np.shape == out_np.shape + @slow @require_torch_gpu diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index e08db91d695f..463ee509d90f 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -14,26 +14,32 @@ # limitations under the License. import gc -import tempfile import unittest import numpy as np import torch from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer -import diffusers from diffusers import ( + AsymmetricAutoencoderKL, AutoencoderKL, + AutoencoderTiny, + ConsistencyDecoderVAE, ControlNetXSAddon, EulerDiscreteScheduler, StableDiffusionXLControlNetXSPipeline, UNet2DConditionModel, ) -from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device from diffusers.utils.torch_utils import randn_tensor +from ...models.autoencoders.test_models_vae import ( + get_asym_autoencoder_kl_config, + get_autoencoder_kl_config, + get_autoencoder_tiny_config, + get_consistency_vae_config, +) from ..pipeline_params import ( IMAGE_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_BATCH_PARAMS, @@ -45,7 +51,6 @@ PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, - to_np, ) @@ -77,9 +82,9 @@ def get_dummy_components(self): out_channels=4, down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + use_linear_projection=True, # SD2-specific config below attention_head_dim=(2, 4), - use_linear_projection=True, addition_embed_type="text_time", addition_time_embed_dim=8, transformer_layers_per_block=(1, 2), @@ -308,6 +313,54 @@ def test_stable_diffusion_xl_prompt_embeds(self): # make sure that it's equal assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1.1e-4 + # copied from test_controlnetxs.py + def test_to_dtype(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the dtype from pipe.components + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) + + pipe.to(dtype=torch.float16) + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) + + def test_multi_vae(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + block_out_channels = pipe.vae.config.block_out_channels + norm_num_groups = pipe.vae.config.norm_num_groups + + vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny] + configs = [ + get_autoencoder_kl_config(block_out_channels, norm_num_groups), + get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups), + get_consistency_vae_config(block_out_channels, norm_num_groups), + get_autoencoder_tiny_config(block_out_channels), + ] + + out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0] + + for vae_cls, config in zip(vae_classes, configs): + vae = vae_cls(**config) + vae = vae.to(torch_device) + components["vae"] = vae + vae_pipe = self.pipeline_class(**components) + + # pipeline creates a new UNetControlNetXSModel under the hood, which aren't on device. + # So we need to move the new pipe to device. + vae_pipe.to(torch_device) + vae_pipe.set_progress_bar_config(disable=None) + + out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0] + + assert out_vae_np.shape == out_np.shape + @slow @require_torch_gpu From 90a6a5096bb0c4cd7a35bdba3a31cbe32c038928 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 14 Mar 2024 03:27:53 +0100 Subject: [PATCH 50/75] Added tests for UNetControlNetXSModel --- src/diffusers/models/controlnet_xs.py | 6 +- src/diffusers/utils/dummy_pt_objects.py | 28 ++--- .../unets/test_models_unet_controlnetxs.py | 104 ++++++++++++------ 3 files changed, 86 insertions(+), 52 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 608e98dd4847..15536ff234a2 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -37,11 +37,11 @@ @dataclass class ControlNetXSOutput(BaseOutput): """ - The output of [`ControlNetXSModel`]. + The output of [`UNetControlNetXSModel`]. Args: sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - The output of the `ControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model + The output of the `UNetControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model output, but is already the final output. """ @@ -960,7 +960,7 @@ def forward( h_base = self.base_conv_out(h_base) if not return_dict: - return h_base + return (h_base,) return ControlNetXSOutput(sample=h_base) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index a13fd90795c6..24e9dbe49d04 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -107,7 +107,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class ControlNetXSModel(metaclass=DummyObject): +class I2VGenXLUNet(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -122,7 +122,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class I2VGenXLUNet(metaclass=DummyObject): +class Kandinsky3UNet(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -137,7 +137,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class Kandinsky3UNet(metaclass=DummyObject): +class ModelMixin(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -152,7 +152,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class ModelMixin(metaclass=DummyObject): +class MotionAdapter(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -167,7 +167,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class MotionAdapter(metaclass=DummyObject): +class MultiAdapter(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -182,7 +182,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class MultiAdapter(metaclass=DummyObject): +class PriorTransformer(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -197,7 +197,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class PriorTransformer(metaclass=DummyObject): +class T2IAdapter(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -212,7 +212,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class T2IAdapter(metaclass=DummyObject): +class T5FilmDecoder(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -227,7 +227,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class T5FilmDecoder(metaclass=DummyObject): +class Transformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -242,7 +242,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class Transformer2DModel(metaclass=DummyObject): +class UNet1DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -257,7 +257,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class UNet1DModel(metaclass=DummyObject): +class UNet2DConditionModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -272,7 +272,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class UNet2DConditionModel(metaclass=DummyObject): +class UNet2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -287,7 +287,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class UNet2DModel(metaclass=DummyObject): +class UNet3DConditionModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -302,7 +302,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class UNet3DConditionModel(metaclass=DummyObject): +class UNetControlNetXSModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index 1c6c4382bf67..df5a592e5a34 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -14,9 +14,7 @@ # limitations under the License. import copy -import os import re -import tempfile import unittest import numpy as np @@ -24,7 +22,6 @@ from diffusers import ControlNetXSAddon, UNet2DConditionModel, UNetControlNetXSModel from diffusers.utils import logging -from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( enable_full_determinism, floats_tensor, @@ -43,23 +40,6 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes model_class = UNetControlNetXSModel main_input_name = "sample" - def get_dummy_components(self, seed=0): - torch.manual_seed(seed) - unet = UNet2DConditionModel( - block_out_channels=(4, 8), - layers_per_block=2, - sample_size=32, - in_channels=4, - out_channels=4, - down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), - up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), - cross_attention_dim=32, - norm_num_groups=1, - use_linear_projection=True, - ) - controlnet = ControlNetXSAddon.from_unet(unet, size_ratio=1) - return unet, controlnet - @property def dummy_input(self): batch_size = 4 @@ -69,18 +49,64 @@ def dummy_input(self): noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) time_step = torch.tensor([10]).to(torch_device) encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device) - - return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + controlnet_cond = floats_tensor((batch_size, 3, 256, 256)).to(torch_device) + conditioning_scale = 1 + + return { + "sample": noise, + "timestep": time_step, + "encoder_hidden_states": encoder_hidden_states, + "controlnet_cond": controlnet_cond, + "conditioning_scale": conditioning_scale, + } @property def input_shape(self): - return (4, 8, 32, 32) + return (4, 32, 32) + + @property + def output_shape(self): + return (4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "sample_size": 32, + "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"), + "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"), + "block_out_channels": (4, 8), + "norm_num_groups": 1, + "cross_attention_dim": 32, + "transformer_layers_per_block": 1, + "num_attention_heads": 8, + "upcast_attention": False, + "ctrl_time_embedding_input_dim": 4, + "ctrl_block_out_channels": [4, 8], + "ctrl_attention_head_dim": 8, + "ctrl_max_norm_num_groups": 1, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def get_dummy_unet(self): + """For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet""" + return UNet2DConditionModel( + block_out_channels=(4, 8), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + norm_num_groups=1, + use_linear_projection=True, + ) def test_from_unet2d(self): - torch.manual_seed(0) - unet2d, controlnet = self.get_dummy_components() + unet = self.get_dummy_unet() + controlnet = ControlNetXSAddon.from_unet(unet, size_ratio=1) - model = UNetControlNetXSModel.from_unet2d(unet2d, controlnet) + model = UNetControlNetXSModel.from_unet2d(unet, controlnet) model_state_dict = model.state_dict() def is_decomposed(module_name): @@ -113,7 +139,7 @@ def block_to_subblock_name(param_name): return param_name - for param_name, param_value in unet2d.named_parameters(): + for param_name, param_value in unet.named_parameters(): if is_decomposed(param_name): # check unet modules that were decomposed self.assertTrue(torch.equal(model_state_dict[block_to_subblock_name(param_name)], param_value)) @@ -126,7 +152,8 @@ def block_to_subblock_name(param_name): self.assertTrue(torch.equal(model_state_dict["control_addon." + param_name], param_value)) def test_freeze_unet2d(self): - model = UNetControlNetXSModel.from_unet2d(*self.get_dummy_components()) + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = UNetControlNetXSModel(**init_dict) model.freeze_unet2d_params() for param_name, param_value in model.named_parameters(): @@ -135,17 +162,22 @@ def test_freeze_unet2d(self): else: self.assertTrue(param_value.requires_grad) - def test_no_control(self): - unet2d, controlnet = self.get_dummy_components() + def test_forward_no_control(self): + unet = self.get_dummy_unet() + controlnet = ControlNetXSAddon.from_unet(unet, size_ratio=1) - model = UNetControlNetXSModel.from_unet2d(unet2d, controlnet) + model = UNetControlNetXSModel.from_unet2d(unet, controlnet) - unet2d = unet2d.to(torch_device) + unet = unet.to(torch_device) model = model.to(torch_device) input_ = self.dummy_input + + control_specific_input = ["controlnet_cond", "conditioning_scale"] + input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input} + with torch.no_grad(): - unet_output = unet2d(**input_).sample.cpu() + unet_output = unet(**input_for_unet).sample.cpu() unet_controlnet_output = model(**input_, do_control=False).sample.cpu() assert np.abs(unet_output.flatten() - unet_controlnet_output.flatten()).max() < 1e-5 @@ -167,7 +199,9 @@ def _set_gradient_checkpointing_new(self, module, value=False): model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new - model = model_class_copy.from_unet2d(*self.get_dummy_components()) + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = model_class_copy(**init_dict) + model.enable_gradient_checkpointing() EXPECTED_SET = { @@ -175,7 +209,7 @@ def _set_gradient_checkpointing_new(self, module, value=False): "UNetMidBlock2DCrossAttn", "CrossAttnDownSubBlock2D", "DownSubBlock2D", - "CrossAttnUpSubBlock2D" + "CrossAttnUpSubBlock2D", } assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET From ef651275c2659bca7363cb31d676b5b73ad5fdbd Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 14 Mar 2024 04:44:53 +0100 Subject: [PATCH 51/75] Fixed SDXL tests --- src/diffusers/models/controlnet_xs.py | 6 ++++++ .../pipelines/controlnet_xs/pipeline_controlnet_xs.py | 2 +- .../controlnet_xs/pipeline_controlnet_xs_sd_xl.py | 8 ++++---- tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py | 4 ++++ 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 15536ff234a2..351cc0721008 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -492,8 +492,10 @@ def __init__( upcast_attention: bool = True, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, time_embedding_dim: Optional[int] = None, time_cond_proj_dim: Optional[int] = None, + projection_class_embeddings_input_dim: Optional[int] = None, # additional controlnet configs time_embedding_mix: float = 1.0, ctrl_conditioning_channels: int = 3, @@ -532,6 +534,8 @@ def __init__( class_embed_type=class_embed_type, addition_embed_type=addition_embed_type, time_cond_proj_dim=time_cond_proj_dim, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + addition_time_embed_dim=addition_time_embed_dim, ) self.in_channels = 4 @@ -636,6 +640,8 @@ def from_unet2d( "class_embed_type", "addition_embed_type", "time_cond_proj_dim", + "projection_class_embeddings_input_dim", + "addition_time_embed_dim", ] config.update({k: v for k, v in unet.config.items() if k in params_for_unet}) # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 5aabc578392c..7e62dad0500d 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -125,7 +125,7 @@ class StableDiffusionControlNetXSPipeline( A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - # todo: dont load controlnet to gpu + # todo umer: dont load controlnet to gpu, its already part of unet model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index b0b87c718acf..295106760f78 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -152,7 +152,7 @@ class StableDiffusionXLControlNetXSPipeline( watermarker is used. """ - # todo: dont load controlnet to gpu + # todo umer: dont load controlnet to gpu, its already part of unet model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" _optional_components = [ "tokenizer", @@ -260,6 +260,7 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( self, prompt: str, @@ -318,7 +319,6 @@ def encode_prompt( Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ - # Note: this is almost an exact copy of `StableDiffusionXLPipeline.encode_prompt` except that `sefl.controlnet` is used instead of `self.unet` device = device or self._execution_device @@ -357,7 +357,7 @@ def encode_prompt( prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - # textual inversion: procecss multi-vector tokens if necessary + # textual inversion: process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): @@ -719,7 +719,7 @@ def _get_add_time_ids( passed_add_embed_dim = ( self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) - expected_add_embed_dim = self.base_add_embedding.linear_1.in_features + expected_add_embed_dim = self.unet.base_add_embedding.linear_1.in_features if expected_add_embed_dim != passed_add_embed_dim: raise ValueError( diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index 463ee509d90f..7295a7006eac 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -313,6 +313,10 @@ def test_stable_diffusion_xl_prompt_embeds(self): # make sure that it's equal assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1.1e-4 + # copied from test_stable_diffusion_xl.py + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + # copied from test_controlnetxs.py def test_to_dtype(self): components = self.get_dummy_components() From f68477b0f22e2e4e70abcc89c91d4cd288248958 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 14 Mar 2024 06:10:48 +0100 Subject: [PATCH 52/75] cleanup --- src/diffusers/models/controlnet_xs.py | 54 ++++++++++------- .../controlnet_xs/pipeline_controlnet_xs.py | 13 ++-- .../pipeline_controlnet_xs_sd_xl.py | 17 +++--- .../unets/test_models_unet_controlnetxs.py | 60 ++++++++++++------- 4 files changed, 85 insertions(+), 59 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 351cc0721008..c2abb24d9114 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -24,7 +24,6 @@ from .autoencoders import AutoencoderKL from .embeddings import ( TimestepEmbedding, - Timesteps, ) from .modeling_utils import ModelMixin from .unets.unet_2d_blocks import Downsample2D, ResnetBlock2D, Transformer2DModel, UNetMidBlock2DCrossAttn, Upsample2D @@ -115,8 +114,10 @@ class ControlNetXSAddon(ModelMixin, ConfigMixin): Dimension of input into time embedding. Needs to be same as in the base model. time_embedding_dim (`int`, defaults to 1280): Dimension of output from time embedding. Needs to be same as in the base model. - time_embedding_mix - # todo umer + time_embedding_mix (`float`, defaults to 1.0): + If 0, then only the control addon's time embedding is used. + If 1, then only the base unet's time embedding is used. + Otherwise, both are combined. learn_time_embedding (`bool`, defaults to `False`): Whether a time embedding should be learned. If yes, `ControlNetXSModel` will combine the time embeddings of the base model and the addon. If no, `ControlNetXSModel` will use the base model's time embedding. @@ -201,6 +202,7 @@ def from_unet( block_out_channels: Optional[List[int]] = None, num_attention_heads: Optional[List[int]] = None, learn_time_embedding: bool = False, + time_embedding_mix: int = 1.0, conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), ): r""" @@ -256,6 +258,7 @@ def from_unet( conditioning_embedding_out_channels=conditioning_embedding_out_channels, time_embedding_input_dim=time_embedding_input_dim, time_embedding_dim=time_embedding_dim, + time_embedding_mix=time_embedding_mix, ) @register_to_config @@ -324,10 +327,8 @@ def __init__( # time if learn_time_embedding: - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0) self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embedding_dim) else: - self.time_proj = None self.time_embedding = None self.time_embed_act = None @@ -454,20 +455,35 @@ def _make_zero_conv(self, in_channels, out_channels=None): class UNetControlNetXSModel(ModelMixin, ConfigMixin): r""" - A ControlNet-XS model + A UNet fused with a ControlNet-XS addon model This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). - `ControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. + `UNetControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. It's default parameters are compatible with StableDiffusion. + Most of it's paremeters are passed to the underlying `UNet2DConditionModel`. See it's documentation for details. + Parameters: - # todo umer time_embedding_mix (`float`, defaults to 1.0): - If 0, then only the base model's time embedding is used. - If 1, then only the control model's time embedding is used. + If 0, then only the control addon's time embedding is used. + If 1, then only the base unet's time embedding is used. Otherwise, both are combined. + ctrl_conditioning_channels (`int`, defaults to 3): + The number of channels of the control conditioning input. + ctrl_conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): + Block sizes of the `ControlNetConditioningEmbedding`. + ctrl_conditioning_channel_order (`str`, defaults to "rgb"): + The order of channels in the control conditioning input. + ctrl_learn_time_embedding (`bool`, defaults to False): + Whether the control addon should learn a time embedding. Needs to be `True` if `time_embedding_mix` > 0. + ctrl_block_out_channels (`tuple[int]`, defaults to `(4, 8, 16, 16)`): + The tuple of output channels for each block in the control addon. + ctrl_attention_head_dim (`int` or `tuple[int]`, defaults to 4): + The dimension of the attention heads in the control addon. + ctrl_max_norm_num_groups (`int`, defaults to 32): + The maximum number of groups to use for the normalization in the control addon. Can be reduced to fit the block sizes. """ _supports_gradient_checkpointing = True @@ -487,12 +503,12 @@ def __init__( block_out_channels: Tuple[int] = (320, 640, 1280, 1280), norm_num_groups: Optional[int] = 32, cross_attention_dim: Union[int, Tuple[int]] = 1024, - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, # type Tuple[Tuple] necessary? + transformer_layers_per_block: Union[int, Tuple[int]] = 1, num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, - upcast_attention: bool = True, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, + upcast_attention: bool = True, time_embedding_dim: Optional[int] = None, time_cond_proj_dim: Optional[int] = None, projection_class_embeddings_input_dim: Optional[int] = None, @@ -500,7 +516,6 @@ def __init__( time_embedding_mix: float = 1.0, ctrl_conditioning_channels: int = 3, ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), - ctrl_time_embedding_input_dim: int = 320, ctrl_conditioning_channel_order: str = "rgb", ctrl_learn_time_embedding: bool = False, ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16), @@ -558,7 +573,7 @@ def __init__( conditioning_channels=ctrl_conditioning_channels, conditioning_channel_order=ctrl_conditioning_channel_order, conditioning_embedding_out_channels=ctrl_conditioning_embedding_out_channels, - time_embedding_input_dim=ctrl_time_embedding_input_dim, + time_embedding_input_dim=block_out_channels[0], time_embedding_dim=time_embedding_dim, time_embedding_mix=time_embedding_mix, learn_time_embedding=ctrl_learn_time_embedding, @@ -575,7 +590,7 @@ def __init__( @classmethod def _unet_to_subblocks(cls, unet: UNet2DConditionModel): - """todo umer""" + """Decompose the down and up blocks of a UNet into subblocks, as required by UNetControlNetXSModel""" down_subblocks = nn.ModuleList() up_subblocks = nn.ModuleList() @@ -621,14 +636,11 @@ def from_unet2d( controlnet: ControlNetXSAddon, load_weights: bool = True, ): - # todo umer: assert unet is sd/sdxl? - # Create config for UNetControlNetXSModel object config = {} config["_class_name"] = cls.__name__ params_for_unet = [ - "time_embedding_dim", "sample_size", "down_block_types", "up_block_types", @@ -636,12 +648,13 @@ def from_unet2d( "norm_num_groups", "cross_attention_dim", "transformer_layers_per_block", - "upcast_attention", "class_embed_type", "addition_embed_type", + "addition_time_embed_dim", + "upcast_attention", + "time_embedding_dim", "time_cond_proj_dim", "projection_class_embeddings_input_dim", - "addition_time_embed_dim", ] config.update({k: v for k, v in unet.config.items() if k in params_for_unet}) # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. @@ -651,7 +664,6 @@ def from_unet2d( "conditioning_channels", "conditioning_embedding_out_channels", "conditioning_channel_order", - "time_embedding_input_dim", "learn_time_embedding", "block_out_channels", "attention_head_dim", diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 7e62dad0500d..90b91f033cdd 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -67,12 +67,11 @@ >>> # initialize the models and pipeline >>> controlnet_conditioning_scale = 0.5 - >>> controlnet_xs_addon = ControlNetXSAddon.from_pretrained( + >>> controlnet = ControlNetXSAddon.from_pretrained( ... "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 ... ) >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-2-1-base", controlnet_xs_addon=controlnet_xs_addon, - ... time_embedding_mix=1.0, torch_dtype=torch.float16 + ... "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16 ... ) # paper used time_embedding_mix=1.0 >>> pipe.enable_model_cpu_offload() @@ -125,7 +124,6 @@ class StableDiffusionControlNetXSPipeline( A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - # todo umer: dont load controlnet to gpu, its already part of unet model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] @@ -815,8 +813,7 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) - # todo umer: what's this for? - controlnet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet + unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -874,7 +871,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare image - if isinstance(controlnet, UNetControlNetXSModel): + if isinstance(unet, UNetControlNetXSModel): image = self.prepare_image( image=image, width=width, @@ -882,7 +879,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=controlnet.dtype, + dtype=unet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, ) height, width = image.shape[-2:] diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index 295106760f78..cf826d96ed05 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -83,13 +83,12 @@ >>> # initialize the models and pipeline >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) - >>> controlnet_xs_addon = ControlNetXSAddon.from_pretrained( + >>> controlnet = ControlNetXSAddon.from_pretrained( ... "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16 ... ) - >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - ... base_path="stabilityai/stable-diffusion-xl-base-1.0", controlnet_xs_addon=controlnet_xs_addon, - ... time_embedding_mix=0.95, torch_dtype=torch.float16 - ... ) # paper used time_embedding_mix=0.95 + >>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, , torch_dtype=torch.float16 + ... ) >>> pipe.enable_model_cpu_offload() >>> # get canny image @@ -152,7 +151,6 @@ class StableDiffusionXLControlNetXSPipeline( watermarker is used. """ - # todo umer: dont load controlnet to gpu, its already part of unet model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" _optional_components = [ "tokenizer", @@ -319,7 +317,6 @@ def encode_prompt( Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ - device = device or self._execution_device # set lora scale so that monkey patched LoRA @@ -955,7 +952,7 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) - controlnet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet + unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -1020,7 +1017,7 @@ def __call__( ) # 4. Prepare image - if isinstance(controlnet, UNetControlNetXSModel): + if isinstance(unet, UNetControlNetXSModel): image = self.prepare_image( image=image, width=width, @@ -1028,7 +1025,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=controlnet.dtype, + dtype=unet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, ) height, width = image.shape[-2:] diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index df5a592e5a34..2bdcd7aef42b 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -162,26 +162,6 @@ def test_freeze_unet2d(self): else: self.assertTrue(param_value.requires_grad) - def test_forward_no_control(self): - unet = self.get_dummy_unet() - controlnet = ControlNetXSAddon.from_unet(unet, size_ratio=1) - - model = UNetControlNetXSModel.from_unet2d(unet, controlnet) - - unet = unet.to(torch_device) - model = model.to(torch_device) - - input_ = self.dummy_input - - control_specific_input = ["controlnet_cond", "conditioning_scale"] - input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input} - - with torch.no_grad(): - unet_output = unet(**input_for_unet).sample.cpu() - unet_controlnet_output = model(**input_, do_control=False).sample.cpu() - - assert np.abs(unet_output.flatten() - unet_controlnet_output.flatten()).max() < 1e-5 - def test_gradient_checkpointing_is_applied(self): model_class_copy = copy.copy(UNetControlNetXSModel) @@ -214,3 +194,43 @@ def _set_gradient_checkpointing_new(self, module, value=False): assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET assert all(modules_with_gc_enabled.values()), "All modules should be enabled" + + def test_forward_no_control(self): + unet = self.get_dummy_unet() + controlnet = ControlNetXSAddon.from_unet(unet, size_ratio=1) + + model = UNetControlNetXSModel.from_unet2d(unet, controlnet) + + unet = unet.to(torch_device) + model = model.to(torch_device) + + input_ = self.dummy_input + + control_specific_input = ["controlnet_cond", "conditioning_scale"] + input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input} + + with torch.no_grad(): + unet_output = unet(**input_for_unet).sample.cpu() + unet_controlnet_output = model(**input_, do_control=False).sample.cpu() + + assert np.abs(unet_output.flatten() - unet_controlnet_output.flatten()).max() < 1e-5 + + def test_time_embedding_mixing(self): + unet = self.get_dummy_unet() + controlnet = ControlNetXSAddon.from_unet(unet, size_ratio=1) + controlnet_mix_time = ControlNetXSAddon.from_unet(unet, size_ratio=1, time_embedding_mix=0.5) + + model = UNetControlNetXSModel.from_unet2d(unet, controlnet) + model_mix_time = UNetControlNetXSModel.from_unet2d(unet, controlnet_mix_time) + + unet = unet.to(torch_device) + model = model.to(torch_device) + model_mix_time = model_mix_time.to(torch_device) + + input_ = self.dummy_input + + with torch.no_grad(): + output = model(**input_).sample + output_mix_time = model_mix_time(**input_).sample + + assert output.shape == output_mix_time.shape From 8521ed278c3669a66dd966d4953764a60a14cb61 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 14 Mar 2024 06:12:45 +0100 Subject: [PATCH 53/75] Delete Pipfile --- Pipfile | 11 ----------- 1 file changed, 11 deletions(-) delete mode 100644 Pipfile diff --git a/Pipfile b/Pipfile deleted file mode 100644 index 0757494bb360..000000000000 --- a/Pipfile +++ /dev/null @@ -1,11 +0,0 @@ -[[source]] -url = "https://pypi.org/simple" -verify_ssl = true -name = "pypi" - -[packages] - -[dev-packages] - -[requires] -python_version = "3.11" From ea69d3c4a5242dd627587f36589003f55a47100b Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 20 Mar 2024 23:19:55 +0100 Subject: [PATCH 54/75] CheckIn Mar 20 Started replacing sub blocks by `ControlNetXSCrossAttnDownBlock2D` and `ControlNetXSCrossAttnUplock2D` --- src/diffusers/models/controlnet_xs.py | 462 +++++++++++++------------- 1 file changed, 240 insertions(+), 222 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index c2abb24d9114..52cdf84f2071 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -880,97 +880,36 @@ def forward( # Cross Control # Let's first define variables to shorten notation - base_down_subblocks = self.base_down_subblocks - ctrl_down_subblocks = self.control_addon.down_subblocks - - down_zero_convs_b2c = self.control_addon.down_zero_convs_b2c - down_zero_convs_c2b = self.control_addon.down_zero_convs_c2b - mid_zero_convs_c2b = self.control_addon.mid_zero_convs_c2b - up_zero_convs_c2b = self.control_addon.up_zero_convs_c2b - - if not do_control: - # Run the base model without control - - # 1 - conv in & down - h_base = self.base_conv_in(h_base) - hs_base.append(h_base) - - for b in base_down_subblocks: - if isinstance(b, CrossAttnDownSubBlock2D): - additional_params = [temb, cemb, attention_mask, cross_attention_kwargs] - else: - additional_params = [] - h_base = b(h_base, *additional_params) - hs_base.append(h_base) - - # 2 - mid - h_base = self.base_mid_block(h_base, temb, cemb, attention_mask, cross_attention_kwargs) - - # 3 - up - for b, skip_b in zip(self.base_up_subblocks, reversed(hs_base)): - h_base = torch.cat([h_base, skip_b], dim=1) # concat info from base encoder - h_base = b(h_base, temb, cemb, attention_mask, cross_attention_kwargs) - - h_base = self.base_conv_norm_out(h_base) - h_base = self.base_conv_act(h_base) - h_base = self.base_conv_out(h_base) - - if not return_dict: - return h_base - - return ControlNetXSOutput(sample=h_base) guided_hint = self.control_addon.controlnet_cond_embedding(controlnet_cond) # 1 - conv in & down - # The base -> ctrl connections are "delayed" by 1 subblock, because we want to "wait" to ensure the new information from the last ctrl -> base connection is also considered. - # Therefore, the connections iterate over: - # ctrl -> base: conv_in | subblock 1 | ... | subblock n - # base -> ctrl: | subblock 1 | ... | subblock n | mid block h_base = self.base_conv_in(h_base) h_ctrl = self.control_addon.conv_in(h_ctrl) if guided_hint is not None: h_ctrl += guided_hint - h_base = h_base + down_zero_convs_c2b[0](h_ctrl) * conditioning_scale # add ctrl -> base + h_base = h_base + self.pre_zero_convs_c2b(h_ctrl) * conditioning_scale # add ctrl -> base # todo umer: define self.pre_zero_convs_c2b hs_base.append(h_base) hs_ctrl.append(h_ctrl) - for b, c, b2c, c2b in zip( - base_down_subblocks, - ctrl_down_subblocks, - down_zero_convs_b2c[:-1], - down_zero_convs_c2b[1:], - ): - if isinstance(b, CrossAttnDownSubBlock2D): - additional_params = [temb, cemb, attention_mask, cross_attention_kwargs] - else: - additional_params = [] - - h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) # concat base -> ctrl - h_base = b(h_base, *additional_params) # apply base subblock - h_ctrl = c(h_ctrl, *additional_params) # apply ctrl subblock - h_base = h_base + c2b(h_ctrl) * conditioning_scale # add ctrl -> base - - hs_base.append(h_base) - hs_ctrl.append(h_ctrl) - h_ctrl = torch.cat([h_ctrl, down_zero_convs_b2c[-1](h_base)], dim=1) # concat base -> ctrl + for down in self.down_blocks: # todo umer: define self.down_blocks + h_base,h_ctrl,residual_hb,residual_hc = down(h_base,h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) + hs_base.extend(residual_hb) + hs_ctrl.extend(residual_hc) # 2 - mid - h_base = self.base_mid_block(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # apply base subblock - h_ctrl = self.control_addon.mid_block( - h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs - ) # apply ctrl subblock - h_base = h_base + mid_zero_convs_c2b(h_ctrl) * conditioning_scale # add ctrl -> base + h_base,h_ctrl = self.mid_block(h_base,h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # todo umer: define self.mid_block # 3 - up - for b, c2b, skip_c, skip_b in zip( - self.base_up_subblocks, up_zero_convs_c2b, reversed(hs_ctrl), reversed(hs_base) - ): - h_base = h_base + c2b(skip_c) * conditioning_scale # add info from ctrl encoder - h_base = torch.cat([h_base, skip_b], dim=1) # concat info from base encoder+ctrl encoder - h_base = b(h_base, temb, cemb, attention_mask, cross_attention_kwargs) + for up in self.up_blocks: # todo umer: define self.up_blocks + n_resnets = len(up.resnets) + skips_hb = hs_base[-n_resnets:] + skips_hc = hs_ctrl[-n_resnets:] + hs_base = hs_base[:-n_resnets] + hs_ctrl = hs_ctrl[:-n_resnets] + h_base = up(h_base,h_ctrl,skips_hb,skips_hc,temb, cemb, attention_mask, cross_attention_kwargs) # 4 - conv out h_base = self.base_conv_norm_out(h_base) @@ -1000,7 +939,7 @@ def find_largest_factor(number, max_factor): factor -= 1 -class CrossAttnDownSubBlock2D(nn.Module): +class ControlNetXSCrossAttnDownBlock2D(nn.Module): def __init__( self, is_empty: bool = False, @@ -1008,48 +947,93 @@ def __init__( out_channels: Optional[int] = None, temb_channels: Optional[int] = None, max_norm_num_groups: Optional[int] = 32, - has_crossattn=False, - transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, + has_crossattn=True, + transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple[int]]]] = 1, num_attention_heads: Optional[int] = 1, cross_attention_dim: Optional[int] = 1024, + add_downsample: bool = True, upcast_attention: Optional[bool] = False, ): super().__init__() - self.gradient_checkpointing = False + base_resnets = [] + base_attentions = [] + ctrl_resnets =[] + ctrl_attentions = [] - if is_empty: - # modules will be set manually, see `CrossAttnSubBlock2D.from_modules` - return + num_layers = 2 # only support sd + sdxl - self.in_channels = in_channels - self.out_channels = out_channels + self.has_cross_attention = has_crossattn + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + base_resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + ) + ) + ctrl_resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + groups=find_largest_factor(in_channels, max_factor=max_norm_num_groups), + groups_out=find_largest_factor(out_channels, max_factor=max_norm_num_groups), + eps=1e-5, + ) + ) - self.resnet = ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - groups=find_largest_factor(in_channels, max_factor=max_norm_num_groups), - groups_out=find_largest_factor(out_channels, max_factor=max_norm_num_groups), - eps=1e-5, - ) + if has_crossattn: + base_attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=True, + upcast_attention=upcast_attention, + ) + ) + ctrl_attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + use_linear_projection=True, + upcast_attention=upcast_attention, + norm_num_groups=find_largest_factor(out_channels, max_factor=max_norm_num_groups), + ) + ) - if has_crossattn: - self.attention = Transformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block, - cross_attention_dim=cross_attention_dim, - use_linear_projection=True, - upcast_attention=upcast_attention, - norm_num_groups=find_largest_factor(out_channels, max_factor=max_norm_num_groups), - ) + self.base_resnets = nn.ModuleList(base_resnets) + self.ctrl_resnets = nn.ModuleList(ctrl_resnets) + self.base_attentions = nn.ModuleList(base_attentions) if has_crossattn else [None]*num_layers + self.ctrl_attentions = nn.ModuleList(ctrl_attentions) if has_crossattn else [None]*num_layers + + if add_downsample: + self.base_downsamplers = Downsample2D(out_channels, use_conv=True, out_channels=out_channels, name="op") + self.ctrl_downsamplers = Downsample2D(out_channels, use_conv=True, out_channels=out_channels, name="op") else: - self.attention = None + self.base_downsamplers = None + self.ctrl_downsamplers = None + + # todo umer: connections b2c, c2b + self.b2c = None + self.c2b = None + + self.gradient_checkpointing = False @classmethod def from_modules(cls, resnet: ResnetBlock2D, attention: Optional[Transformer2DModel] = None): """Create empty subblock and set resnet and attention manually""" + # todo umer subblock = cls(is_empty=True) subblock.resnet = resnet subblock.attention = attention @@ -1059,100 +1043,144 @@ def from_modules(cls, resnet: ResnetBlock2D, attention: Optional[Transformer2DMo def forward( self, - hidden_states: torch.FloatTensor, + hidden_states_base: torch.FloatTensor, + hidden_states_ctrl: torch.FloatTensor, + conditioning_scale: Optional[float] = 1.0, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - if self.resnet is not None: - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - if self.attention is not None: - hidden_states = self.attention( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - else: - if self.resnet is not None: - hidden_states = self.resnet(hidden_states, temb, scale=lora_scale) - if self.attention is not None: - hidden_states = self.attention( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: # todo umer: output type hint correct? + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - return hidden_states + h_base = hidden_states_base + h_ctrl = hidden_states_ctrl + + base_output_states = () + ctrl_output_states = () + base_blocks = list(zip(self.base_resnets, self.base_attentions)) + ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions)) -class DownSubBlock2D(nn.Module): + for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip(base_blocks, ctrl_blocks, self.b2c, self.c2b): + if self.training and self.gradient_checkpointing: + raise NotImplementedError("todo umer") + else: + # concat base -> ctrl + h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) + + # apply base subblock + h_base = b_res(h_base, temb) + if b_attn is not None: + h_base = b_attn( + h_base, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply ctrl subblock + h_ctrl = c_res(h_ctrl, temb) + if c_attn is not None: + h_ctrl = c_attn( + h_ctrl, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # add ctrl -> base + h_base = h_base + c2b(h_ctrl) * conditioning_scale + + base_output_states = base_output_states + (h_base,) + ctrl_output_states = ctrl_output_states + (h_ctrl,) + + if self.base_downsamplers is not None: # if we have a base_downsampler, then also a ctrl_downsampler + # concat base -> ctrl + h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) + # apply base subblock + h_base = self.base_downsamplers(h_base) + # apply ctrl subblock + h_ctrl = self.ctrl_downsamplers(h_ctrl) + # add ctrl -> base + h_base = h_base + c2b(h_ctrl) * conditioning_scale + + base_output_states = base_output_states + (h_base,) + ctrl_output_states = ctrl_output_states + (h_ctrl,) + + return h_base, h_ctrl,base_output_states, ctrl_output_states + + +class ControlNetXSCrossAttnUplock2D(nn.Module): def __init__( self, is_empty: bool = False, in_channels: Optional[int] = None, out_channels: Optional[int] = None, + prev_output_channel: Optional[int] = None, + temb_channels: Optional[int] = None, + has_crossattn=True, + transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple[int]]]] = 1, + num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + add_upsample: bool = True, + upcast_attention: bool = False, ): super().__init__() - self.gradient_checkpointing = False + resnets = [] + attentions = [] - if is_empty: - # downsampler will be set manually, see `DownSubBlock2D.from_modules` - return + num_layers = 3 # only support sd + sdxl - self.in_channels = in_channels - self.out_channels = out_channels - self.downsampler = Downsample2D(in_channels, use_conv=True, out_channels=out_channels, name="op") + self.has_cross_attention = has_crossattn + self.num_attention_heads = num_attention_heads - @classmethod - def from_modules(cls, downsampler: Downsample2D): - """Create empty subblock and set downsampler manually""" - subblock = cls(is_empty=True) - subblock.downsampler = downsampler - subblock.in_channels = downsampler.channels - subblock.out_channels = downsampler.out_channels - return subblock + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers - def forward( - self, - hidden_states: torch.FloatTensor, - ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - return self.downsampler(hidden_states) + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + ) + ) + + if has_crossattn: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=True, + upcast_attention=upcast_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) if has_crossattn else [None]*num_layers + + if add_upsample: + self.upsamplers = Upsample2D(out_channels, use_conv=True, out_channels=out_channels) + else: + self.upsamplers = None + + # todo umer: c2b + self.c2b = None -class CrossAttnUpSubBlock2D(nn.Module): - def __init__(self): - """ - In the context of ControlNet-XS, `CrossAttnUpSubBlock2D` are only loaded from existing modules, and not created from scratch. - Therefore, `__init__` is left almost empty. - """ - super().__init__() self.gradient_checkpointing = False @classmethod @@ -1163,6 +1191,7 @@ def from_modules( upsampler: Optional[Upsample2D] = None, ): """Create empty subblock and set resnet, attention and upsampler manually""" + # todo umer subblock = cls() subblock.resnet = resnet subblock.attention = attention @@ -1174,55 +1203,44 @@ def from_modules( def forward( self, hidden_states: torch.FloatTensor, + res_hidden_states_tuple_base: Tuple[torch.FloatTensor, ...], + res_hidden_states_tuple_cltr: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - if self.attention is not None: - hidden_states = self.attention( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states) - else: - hidden_states = self.resnet(hidden_states, temb, scale=lora_scale) - if self.attention is not None: - hidden_states = self.attention( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states) + ) -> torch.FloatTensor: # todo umer: output type hint correct? + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + for resnet, attn, c2b, res_h_base, res_h_ctrl in zip(self.resnets, self.attentions, self.c2b, reversed(res_hidden_states_tuple_base), reversed(res_hidden_states_tuple_cltr)): + hidden_states += c2b(res_h_ctrl) + hidden_states = torch.cat([hidden_states, res_h_base], dim=1) + + if self.training and self.gradient_checkpointing: + raise NotImplementedError("todo umer") + else: + hidden_states = resnet(hidden_states, temb) + if attn is not None: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsampler is not None: + c2b = self.c2b[-1] + res_h_base = res_hidden_states_tuple_base[0] + res_h_ctrl = res_hidden_states_tuple_cltr[0] + + hidden_states += c2b(res_h_ctrl) + hidden_states = torch.cat([hidden_states, res_h_base], dim=1) + hidden_states = self.upsampler(hidden_states, upsample_size) return hidden_states From f3f569ba5fb59e152f64304bd0e33bdd09fe1bd8 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Sat, 23 Mar 2024 16:21:38 +0100 Subject: [PATCH 55/75] check-in Mar 23 --- src/diffusers/models/controlnet_xs.py | 446 ++++++++++++++++++-------- 1 file changed, 312 insertions(+), 134 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 52cdf84f2071..a740240a6c86 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -20,7 +20,7 @@ from torch.nn import functional as F from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, is_torch_version, logging +from ..utils import BaseOutput, logging from .autoencoders import AutoencoderKL from .embeddings import ( TimestepEmbedding, @@ -422,20 +422,20 @@ def __init__( # 4.1 - Connections from base encoder to ctrl encoder # As the information is concatted to ctrl, the channels sizes don't change. for c in channels_base["down - out"]: - self.down_zero_convs_b2c.append(self._make_zero_conv(c, c)) + self.down_zero_convs_b2c.append(make_zero_conv(c, c)) # 4.2 - Connections from ctrl encoder to base encoder # As the information is added to base, the out-channels need to match base. for ch_base, ch_ctrl in zip(channels_base["down - out"], channels_ctrl["down - out"]): - self.down_zero_convs_c2b.append(self._make_zero_conv(ch_ctrl, ch_base)) + self.down_zero_convs_c2b.append(make_zero_conv(ch_ctrl, ch_base)) # 4.3 - Connections in mid block - self.mid_zero_convs_c2b = self._make_zero_conv(channels_ctrl["mid - out"], channels_base["mid - out"]) + self.mid_zero_convs_c2b = make_zero_conv(channels_ctrl["mid - out"], channels_base["mid - out"]) # 4.3 - Connections from ctrl encoder to base decoder skip_channels = reversed(channels_ctrl["down - out"]) for s, i in zip(skip_channels, channels_base["up - in"]): - self.up_zero_convs_c2b.append(self._make_zero_conv(s, i)) + self.up_zero_convs_c2b.append(make_zero_conv(s, i)) # 5 - Create conditioning hint embedding self.controlnet_cond_embedding = ControlNetConditioningEmbedding( @@ -449,9 +449,6 @@ def forward(self, *args, **kwargs): "A ControlNetXSAddonModel cannot be run by itself. Pass it into a ControlNetXSModel model instead." ) - def _make_zero_conv(self, in_channels, out_channels=None): - return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) - class UNetControlNetXSModel(ModelMixin, ConfigMixin): r""" @@ -504,7 +501,7 @@ def __init__( norm_num_groups: Optional[int] = 32, cross_attention_dim: Union[int, Tuple[int]] = 1024, transformer_layers_per_block: Union[int, Tuple[int]] = 1, - num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, + num_attention_heads: Union[int, Tuple[int]] = 8, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, @@ -531,6 +528,13 @@ def __init__( "To use `time_embedding_mix` < 1, initialize `ctrl_addon` with `learn_time_embedding = True`" ) + def repeat_if_not_list(value, repetitions): + return value if isinstance(value, (tuple, list)) else [value] * repetitions + + transformer_layers_per_block = repeat_if_not_list(transformer_layers_per_block, repetitions=len(down_block_types)) + cross_attention_dim = repeat_if_not_list(cross_attention_dim, repetitions=len(down_block_types)) + num_attention_heads = repeat_if_not_list(num_attention_heads, repetitions=len(down_block_types)) + time_embedding_dim = time_embedding_dim or block_out_channels[0] * 4 # Create UNet and decompose it into subblocks, which we then save @@ -567,67 +571,76 @@ def __init__( self.base_conv_act = base_model.conv_act self.base_conv_out = base_model.conv_out - self.base_down_subblocks, self.base_up_subblocks = UNetControlNetXSModel._unet_to_subblocks(base_model) - - self.control_addon = ControlNetXSAddon( - conditioning_channels=ctrl_conditioning_channels, - conditioning_channel_order=ctrl_conditioning_channel_order, - conditioning_embedding_out_channels=ctrl_conditioning_embedding_out_channels, - time_embedding_input_dim=block_out_channels[0], - time_embedding_dim=time_embedding_dim, - time_embedding_mix=time_embedding_mix, - learn_time_embedding=ctrl_learn_time_embedding, - channels_base=ControlNetXSAddon.gather_base_subblock_sizes(block_out_channels), - attention_head_dim=ctrl_attention_head_dim, - block_out_channels=ctrl_block_out_channels, - cross_attention_dim=cross_attention_dim, - down_block_types=down_block_types, - sample_size=sample_size, - transformer_layers_per_block=transformer_layers_per_block, - upcast_attention=upcast_attention, - max_norm_num_groups=ctrl_max_norm_num_groups, + down_blocks = [] + up_blocks = [] + + # create down blocks + def left_shifted_iterator_pairs(iterable, keys=["in", "out"]): + """e.g. [0,1,2,3] -> [({"in":0,"out":0}, {"in":0,"out":1}, {"in":1,"out":2}, {"in":2,"out":3}]""" + left_shifted_iterable = iterable[0] + list(iterable[:-1]) + return [ + {keys[0]: o1, keys[1]: o2} + for o1,o2 in zip(left_shifted_iterable, iterable) + ] + + channels = {"base": left_shifted_iterator_pairs(block_out_channels), "ctrl": left_shifted_iterator_pairs(ctrl_block_out_channels)} + + for i, (down_block_type, b_channels, c_channels) in enumerate((down_block_types, channels["base"], channels["ctrl"])): + has_crossattn = "CrossAttn" in down_block_type + add_downsample = i==len(down_block_types)-1 + + down_blocks.append(ControlNetXSCrossAttnDownBlock2D( + base_in_channels = b_channels["in"], + base_out_channels = b_channels["out"], + ctrl_in_channels = c_channels["in"], + ctrl_out_channels = c_channels["out"], + temb_channels = base_model.config.time_embedding_dim, + max_norm_num_groups = ctrl_max_norm_num_groups.max_norm_num_groups, + has_crossattn = has_crossattn, + transformer_layers_per_block = transformer_layers_per_block[i], + num_attention_heads = num_attention_heads[i], + cross_attention_dim = cross_attention_dim[i], + add_downsample = add_downsample, + upcast_attention = upcast_attention + )) + + # create down blocks + self.mid_block = ControlNetXSCrossAttnMidBlock2D( + base_channels=block_out_channels[-1], + ctrl_channels=ctrl_block_out_channels[-1], + temb_channels = base_model.config.time_embedding_dim, + transformer_layers_per_block = transformer_layers_per_block[-1], + num_attention_heads = num_attention_heads[-1], + cross_attention_dim = cross_attention_dim[-1], + upcast_attention = upcast_attention, ) - @classmethod - def _unet_to_subblocks(cls, unet: UNet2DConditionModel): - """Decompose the down and up blocks of a UNet into subblocks, as required by UNetControlNetXSModel""" - down_subblocks = nn.ModuleList() - up_subblocks = nn.ModuleList() - - for block in unet.down_blocks: - # Each ResNet / Attention pair is a subblock - resnets = block.resnets - attentions = block.attentions if hasattr(block, "attentions") else [None] * len(resnets) - for r, a in zip(resnets, attentions): - down_subblocks.append(CrossAttnDownSubBlock2D.from_modules(r, a)) - # Each Downsampler is a subblock - if block.downsamplers is not None: - if len(block.downsamplers) != 1: - raise ValueError( - "ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL." - "Therefore each down block of the base model should have only 1 downsampler (if any)." - ) - down_subblocks.append(DownSubBlock2D.from_modules(block.downsamplers[0])) - - for block in unet.up_blocks: - # Each ResNet / Attention / Upsampler triple is a subblock - if block.upsamplers is not None: - if len(block.upsamplers) != 1: - raise ValueError( - "ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL." - "Therefore each up block of the base model should have only 1 upsampler (if any)." - ) - upsampler = block.upsamplers[0] - else: - upsampler = None - - resnets = block.resnets - attentions = block.attentions if hasattr(block, "attentions") else [None] * len(resnets) - upsamplers = [None] * (len(resnets) - 1) + [upsampler] - for r, a, u in zip(resnets, attentions, upsamplers): - up_subblocks.append(CrossAttnUpSubBlock2D.from_modules(r, a, u)) - - return down_subblocks, up_subblocks + # create up blocks + rev_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + rev_num_attention_heads = list(reversed(num_attention_heads)) + rev_cross_attention_dim = list(reversed(cross_attention_dim)) + rev_block_out_channels = list(reversed(block_out_channels)) + + for i, up_block_type in enumerate(up_block_types): + has_crossattn = "CrossAttn" in down_block_type + add_upsample = i>0 # todo umer: correct? + + up_blocks.append(ControlNetXSCrossAttnUpBlock2D(# todo umer + in_channels = 123456, + out_channels = 123456, + prev_output_channel = 123456, + ctrl_skip_channels = [123456, 123456], + temb_channels = base_model.config.time_embedding_dim, + has_crossattn = has_crossattn, + transformer_layers_per_block = rev_transformer_layers_per_block[-1], + num_attention_heads = rev_num_attention_heads[-1], + cross_attention_dim = rev_cross_attention_dim[-1], + add_upsample = add_upsample, + upcast_attention = upcast_attention, + )) + + self.down_bocks = nn.ModuleList(down_blocks) + self.up_bocks = nn.ModuleList(up_blocks) @classmethod def from_unet2d( @@ -922,30 +935,14 @@ def forward( return ControlNetXSOutput(sample=h_base) -def zero_module(module): - for p in module.parameters(): - nn.init.zeros_(p) - return module - - -def find_largest_factor(number, max_factor): - factor = max_factor - if factor >= number: - return number - while factor != 0: - residual = number % factor - if residual == 0: - return factor - factor -= 1 - - class ControlNetXSCrossAttnDownBlock2D(nn.Module): def __init__( self, - is_empty: bool = False, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - temb_channels: Optional[int] = None, + base_in_channels: int, + base_out_channels: int, + ctrl_in_channels: int, + ctrl_out_channels: int, + temb_channels: int, max_norm_num_groups: Optional[int] = 32, has_crossattn=True, transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple[int]]]] = 1, @@ -959,6 +956,8 @@ def __init__( base_attentions = [] ctrl_resnets =[] ctrl_attentions = [] + ctrl_to_base = [] + base_to_ctrl = [] num_layers = 2 # only support sd + sdxl @@ -968,21 +967,27 @@ def __init__( transformer_layers_per_block = [transformer_layers_per_block] * num_layers for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels + base_in_channels = base_in_channels if i == 0 else base_out_channels + ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_in_channels + + # Before the resnet/attention application, information is concatted from base to control. + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels)) + base_resnets.append( ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, + in_channels=base_in_channels, + out_channels=base_out_channels, temb_channels=temb_channels, ) ) ctrl_resnets.append( ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, + in_channels=ctrl_in_channels, + out_channels=ctrl_in_channels, temb_channels=temb_channels, - groups=find_largest_factor(in_channels, max_factor=max_norm_num_groups), - groups_out=find_largest_factor(out_channels, max_factor=max_norm_num_groups), + groups=find_largest_factor(ctrl_in_channels, max_factor=max_norm_num_groups), + groups_out=find_largest_factor(ctrl_in_channels, max_factor=max_norm_num_groups), eps=1e-5, ) ) @@ -991,8 +996,8 @@ def __init__( base_attentions.append( Transformer2DModel( num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, + base_out_channels // num_attention_heads, + in_channels=base_out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, use_linear_projection=True, @@ -1002,44 +1007,75 @@ def __init__( ctrl_attentions.append( Transformer2DModel( num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, + ctrl_out_channels // num_attention_heads, + in_channels=ctrl_out_channels, num_layers=transformer_layers_per_block, cross_attention_dim=cross_attention_dim, use_linear_projection=True, upcast_attention=upcast_attention, - norm_num_groups=find_largest_factor(out_channels, max_factor=max_norm_num_groups), + norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), ) ) - self.base_resnets = nn.ModuleList(base_resnets) - self.ctrl_resnets = nn.ModuleList(ctrl_resnets) - self.base_attentions = nn.ModuleList(base_attentions) if has_crossattn else [None]*num_layers - self.ctrl_attentions = nn.ModuleList(ctrl_attentions) if has_crossattn else [None]*num_layers + # After the resnet/attention application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) if add_downsample: - self.base_downsamplers = Downsample2D(out_channels, use_conv=True, out_channels=out_channels, name="op") - self.ctrl_downsamplers = Downsample2D(out_channels, use_conv=True, out_channels=out_channels, name="op") + # Before the downsampler application, information is concatted from base to control + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) + + self.base_downsamplers = Downsample2D(base_out_channels, use_conv=True, out_channels=base_out_channels, name="op") + self.ctrl_downsamplers = Downsample2D(ctrl_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op") + + # After the downsampler application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) else: self.base_downsamplers = None self.ctrl_downsamplers = None - # todo umer: connections b2c, c2b - self.b2c = None - self.c2b = None + self.base_resnets = nn.ModuleList(base_resnets) + self.ctrl_resnets = nn.ModuleList(ctrl_resnets) + self.base_attentions = nn.ModuleList(base_attentions) if has_crossattn else [None]*num_layers + self.ctrl_attentions = nn.ModuleList(ctrl_attentions) if has_crossattn else [None]*num_layers + self.base_to_ctrl = nn.ModuleList(base_to_ctrl) + self.ctrl_to_base = nn.ModuleList(ctrl_to_base) self.gradient_checkpointing = False @classmethod - def from_modules(cls, resnet: ResnetBlock2D, attention: Optional[Transformer2DModel] = None): - """Create empty subblock and set resnet and attention manually""" - # todo umer - subblock = cls(is_empty=True) - subblock.resnet = resnet - subblock.attention = attention - subblock.in_channels = resnet.in_channels - subblock.out_channels = resnet.out_channels - return subblock + def from_modules( + cls, + base_resnets: List[ResnetBlock2D], ctrl_resnets: List[ResnetBlock2D], + base_to_control_connections: List[nn.Conv2d], control_to_base_connections: List[nn.Conv2d], + base_attentions: Optional[List[Transformer2DModel]] = None, ctrl_attentions: Optional[List[Transformer2DModel]] = None, + base_downsampler: Optional[List[Transformer2DModel]] = None, ctrl_downsampler: Optional[List[Transformer2DModel]] = None,): + """todo umer""" + block = cls( + in_channels = None, + out_channels = None, + temb_channels = None, + max_norm_num_groups = 32, + has_crossattn = True, + transformer_layers_per_block = 1, + num_attention_heads = 1, + cross_attention_dim = 1024, + add_downsample = True, + upcast_attention = False, + ) + + block.base_resnets = base_resnets + block.base_attentions = base_attentions + block.ctrl_resnets = ctrl_resnets + block.ctrl_attentions = ctrl_attentions + block.b2c = base_to_control_connections + block.c2b = control_to_base_connections + block.base_downsampler = base_downsampler + block.ctrl_downsampler = ctrl_downsampler + + return block def forward( self, @@ -1065,7 +1101,7 @@ def forward( base_blocks = list(zip(self.base_resnets, self.base_attentions)) ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions)) - for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip(base_blocks, ctrl_blocks, self.b2c, self.c2b): + for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip(base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base): if self.training and self.gradient_checkpointing: raise NotImplementedError("todo umer") else: @@ -1103,6 +1139,9 @@ def forward( ctrl_output_states = ctrl_output_states + (h_ctrl,) if self.base_downsamplers is not None: # if we have a base_downsampler, then also a ctrl_downsampler + b2c = self.base_to_ctrl[-1] + c2b = self.ctrl_to_base[-1] + # concat base -> ctrl h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) # apply base subblock @@ -1118,24 +1157,119 @@ def forward( return h_base, h_ctrl,base_output_states, ctrl_output_states -class ControlNetXSCrossAttnUplock2D(nn.Module): +class ControlNetXSCrossAttnMidBlock2D(nn.Module): def __init__( self, - is_empty: bool = False, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - prev_output_channel: Optional[int] = None, + base_channels: int, + ctrl_channels: int, temb_channels: Optional[int] = None, - has_crossattn=True, - transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple[int]]]] = 1, + transformer_layers_per_block: int = 1, num_attention_heads: Optional[int] = 1, cross_attention_dim: Optional[int] = 1024, + upcast_attention: bool = False, + ): + super().__init__() + + # Before the midblock application, information is concatted from base to control. + # Concat doesn't require change in number of channels + self.base_to_ctrl = make_zero_conv(base_channels, base_channels) + + self.base_midblock = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=base_channels, + temb_channels=temb_channels, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + use_linear_projection=True, + upcast_attention=upcast_attention + ) + self.ctrl_midblock = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=ctrl_channels + base_channels, + out_channels=ctrl_channels, + temb_channels=temb_channels, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, # todo umer: n_attn_heads different for base / ctrl? + use_linear_projection=True, + upcast_attention=upcast_attention + ) + + # After the midblock application, information is added from control to base + # Addition requires change in number of channels + self.ctrl_to_base = make_zero_conv(ctrl_channels, base_channels) + + self.gradient_checkpointing = False + + @classmethod + def from_modules( + cls, + resnet: ResnetBlock2D, + attention: Optional[Transformer2DModel] = None, + upsampler: Optional[Upsample2D] = None, + ): + """Create empty subblock and set resnet, attention and upsampler manually""" + # todo umer + subblock = cls() + subblock.resnet = resnet + subblock.attention = attention + subblock.upsampler = upsampler + subblock.in_channels = resnet.in_channels + subblock.out_channels = resnet.out_channels + return subblock + + def forward( + self, + hidden_states_base: torch.FloatTensor, + hidden_states_ctrl: torch.FloatTensor, + conditioning_scale: Optional[float] = 1.0, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: # todo umer: output type hint correct? + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + h_base = hidden_states_base + h_ctrl = hidden_states_ctrl + + joint_args = { + "temb": temb, + "encoder_hidden_states": encoder_hidden_states, + "attention_mask": attention_mask, + "cross_attention_kwargs": cross_attention_kwargs, + "encoder_attention_mask": encoder_attention_mask, + } + + h_ctrl = torch.cat([h_ctrl, self.base_to_ctrl(h_base)], dim=1) # concat base -> ctrl + h_base = self.base_midblock(h_base, **joint_args) # apply base mid block + h_ctrl = self.ctrl_midblock(h_ctrl, **joint_args) # apply ctrl mid block + h_base = h_base + self.ctrl_to_base(h_ctrl) * conditioning_scale # add ctrl -> base + + return h_base, h_ctrl + + +class ControlNetXSCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + ctrl_skip_channels: List[int], + temb_channels: int, + has_crossattn=True, + transformer_layers_per_block: int = 1, + num_attention_heads: int = 1, + cross_attention_dim: int = 1024, add_upsample: bool = True, upcast_attention: bool = False, ): super().__init__() resnets = [] attentions = [] + ctrl_to_base = [] num_layers = 3 # only support sd + sdxl @@ -1149,6 +1283,8 @@ def __init__( res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels + ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels)) + resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, @@ -1172,15 +1308,13 @@ def __init__( self.resnets = nn.ModuleList(resnets) self.attentions = nn.ModuleList(attentions) if has_crossattn else [None]*num_layers + self.ctrl_to_base = nn.ModuleList(ctrl_to_base) if add_upsample: self.upsamplers = Upsample2D(out_channels, use_conv=True, out_channels=out_channels) else: self.upsamplers = None - # todo umer: c2b - self.c2b = None - self.gradient_checkpointing = False @classmethod @@ -1205,6 +1339,7 @@ def forward( hidden_states: torch.FloatTensor, res_hidden_states_tuple_base: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple_cltr: Tuple[torch.FloatTensor, ...], + conditioning_scale: Optional[float] = 1.0, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -1216,8 +1351,19 @@ def forward( if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - for resnet, attn, c2b, res_h_base, res_h_ctrl in zip(self.resnets, self.attentions, self.c2b, reversed(res_hidden_states_tuple_base), reversed(res_hidden_states_tuple_cltr)): - hidden_states += c2b(res_h_ctrl) + # In ControlNet-XS, the last resnet/attention and the upsampler are treated as a group. + # So we separate them to pass information from ctrl to base correctly. + if self.upsamplers is None: + resnets_without_upsampler = self.resnets + attn_without_upsampler = self.attentions + else: + resnets_without_upsampler = self.resnets[:-1] + attn_without_upsampler = self.attentions[:-1] + resnet_with_upsampler = self.resnets[-1] + attn_with_upsampler = self.attentions[-1] + + for resnet, attn, c2b, res_h_base, res_h_ctrl in zip(resnets_without_upsampler, attn_without_upsampler, self.ctrl_to_base, reversed(res_hidden_states_tuple_base), reversed(res_hidden_states_tuple_cltr)): + hidden_states += c2b(res_h_ctrl) * conditioning_scale hidden_states = torch.cat([hidden_states, res_h_base], dim=1) if self.training and self.gradient_checkpointing: @@ -1235,12 +1381,44 @@ def forward( )[0] if self.upsampler is not None: - c2b = self.c2b[-1] + c2b = self.ctrl_to_base[-1] res_h_base = res_hidden_states_tuple_base[0] res_h_ctrl = res_hidden_states_tuple_cltr[0] - hidden_states += c2b(res_h_ctrl) + hidden_states += c2b(res_h_ctrl) * conditioning_scale hidden_states = torch.cat([hidden_states, res_h_base], dim=1) + + hidden_states = resnet_with_upsampler(hidden_states, temb) + if attn_with_upsampler is not None: + hidden_states = attn_with_upsampler( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = self.upsampler(hidden_states, upsample_size) return hidden_states + + +def make_zero_conv(in_channels, out_channels=None): + return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +def find_largest_factor(number, max_factor): + factor = max_factor + if factor >= number: + return number + while factor != 0: + residual = number % factor + if residual == 0: + return factor + factor -= 1 From 56d69fe5288a6d0cf93d901f15c89f945b711403 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Sun, 24 Mar 2024 18:20:48 +0100 Subject: [PATCH 56/75] checkin 24 Mar --- src/diffusers/models/controlnet_xs.py | 185 +++++++++++++++----------- 1 file changed, 110 insertions(+), 75 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index a740240a6c86..a885b5576c6a 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass +from math import gcd from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -506,7 +507,6 @@ def __init__( addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, upcast_attention: bool = True, - time_embedding_dim: Optional[int] = None, time_cond_proj_dim: Optional[int] = None, projection_class_embeddings_input_dim: Optional[int] = None, # additional controlnet configs @@ -516,7 +516,7 @@ def __init__( ctrl_conditioning_channel_order: str = "rgb", ctrl_learn_time_embedding: bool = False, ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16), - ctrl_attention_head_dim: Union[int, Tuple[int]] = 4, + ctrl_attention_head_dim: Union[int, Tuple[int]] = 4, # todo umer: # attn heads or dim attn heads? ctrl_max_norm_num_groups: int = 32, ): super().__init__() @@ -533,9 +533,8 @@ def repeat_if_not_list(value, repetitions): transformer_layers_per_block = repeat_if_not_list(transformer_layers_per_block, repetitions=len(down_block_types)) cross_attention_dim = repeat_if_not_list(cross_attention_dim, repetitions=len(down_block_types)) - num_attention_heads = repeat_if_not_list(num_attention_heads, repetitions=len(down_block_types)) - - time_embedding_dim = time_embedding_dim or block_out_channels[0] * 4 + base_num_attention_heads = repeat_if_not_list(num_attention_heads, repetitions=len(down_block_types)) + ctrl_attention_head_dim = repeat_if_not_list(ctrl_attention_head_dim, repetitions=len(down_block_types)) # Create UNet and decompose it into subblocks, which we then save base_model = UNet2DConditionModel( @@ -549,7 +548,6 @@ def repeat_if_not_list(value, repetitions): attention_head_dim=num_attention_heads, use_linear_projection=True, upcast_attention=upcast_attention, - time_embedding_dim=time_embedding_dim, class_embed_type=class_embed_type, addition_embed_type=addition_embed_type, time_cond_proj_dim=time_cond_proj_dim, @@ -559,6 +557,9 @@ def repeat_if_not_list(value, repetitions): self.in_channels = 4 + time_embed_input_dim = block_out_channels[0] + time_embed_dim = block_out_channels[0] * 4 + self.base_time_proj = base_model.time_proj self.base_time_embedding = base_model.time_embedding self.base_class_embedding = base_model.class_embedding @@ -571,66 +572,93 @@ def repeat_if_not_list(value, repetitions): self.base_conv_act = base_model.conv_act self.base_conv_out = base_model.conv_out + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=ctrl_block_out_channels[0], + block_out_channels=ctrl_conditioning_embedding_out_channels, + conditioning_channels=ctrl_conditioning_channels, + ) + self.ctrl_conv_in = nn.Conv2d(4, ctrl_block_out_channels[0], kernel_size=3, padding=1) + self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim) + + self.control_to_base_for_conv_in = make_zero_conv(ctrl_block_out_channels[0], block_out_channels[0]) + down_blocks = [] up_blocks = [] - # create down blocks + # # Create down blocks def left_shifted_iterator_pairs(iterable, keys=["in", "out"]): """e.g. [0,1,2,3] -> [({"in":0,"out":0}, {"in":0,"out":1}, {"in":1,"out":2}, {"in":2,"out":3}]""" - left_shifted_iterable = iterable[0] + list(iterable[:-1]) + left_shifted_iterable = [iterable[0]] + list(iterable[:-1]) return [ {keys[0]: o1, keys[1]: o2} for o1,o2 in zip(left_shifted_iterable, iterable) ] - channels = {"base": left_shifted_iterator_pairs(block_out_channels), "ctrl": left_shifted_iterator_pairs(ctrl_block_out_channels)} + down_block_channels = {"base": left_shifted_iterator_pairs(block_out_channels), "ctrl": left_shifted_iterator_pairs(ctrl_block_out_channels)} - for i, (down_block_type, b_channels, c_channels) in enumerate((down_block_types, channels["base"], channels["ctrl"])): + for i, (down_block_type, b_channels, c_channels) in enumerate(zip(down_block_types, down_block_channels["base"], down_block_channels["ctrl"])): has_crossattn = "CrossAttn" in down_block_type - add_downsample = i==len(down_block_types)-1 + add_downsample = i (3,2,1) + # into -> [{"in": 3, "out": 3}, {"in": 3, "out": 2}, {"in": 2, "out": 1}] + rev_down_block_channels = left_shifted_iterator_pairs(list(reversed(block_out_channels))) for i, up_block_type in enumerate(up_block_types): - has_crossattn = "CrossAttn" in down_block_type - add_upsample = i>0 # todo umer: correct? - - up_blocks.append(ControlNetXSCrossAttnUpBlock2D(# todo umer - in_channels = 123456, - out_channels = 123456, - prev_output_channel = 123456, - ctrl_skip_channels = [123456, 123456], - temb_channels = base_model.config.time_embedding_dim, + has_crossattn = "CrossAttn" in up_block_type + add_upsample = i0 else in_channels + ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)] + + up_blocks.append(ControlNetXSCrossAttnUpBlock2D( + in_channels = in_channels, + out_channels = out_channels, + prev_output_channel = prev_output_channel, + ctrl_skip_channels = ctrl_skip_channels_, + temb_channels = time_embed_dim, has_crossattn = has_crossattn, transformer_layers_per_block = rev_transformer_layers_per_block[-1], num_attention_heads = rev_num_attention_heads[-1], @@ -639,8 +667,12 @@ def left_shifted_iterator_pairs(iterable, keys=["in", "out"]): upcast_attention = upcast_attention, )) - self.down_bocks = nn.ModuleList(down_blocks) - self.up_bocks = nn.ModuleList(up_blocks) + self.down_blocks = nn.ModuleList(down_blocks) + self.up_blocks = nn.ModuleList(up_blocks) + + # todo umer: create control_addon.conv_in + # todo umer: create ctrl_time_embedding + # tood umer: creatae b2c for conv->down0 @classmethod def from_unet2d( @@ -727,6 +759,7 @@ def from_unet2d( return model def freeze_unet2d_params(self) -> None: + # todo umer """Freeze the weights of just the UNet2DConditionModel, and leave the ControlNetXSAddon unfrozen for fine tuning. """ @@ -740,7 +773,7 @@ def freeze_unet2d_params(self) -> None: @torch.no_grad() def _check_if_vae_compatible(self, vae: AutoencoderKL): - condition_downscale_factor = 2 ** (len(self.control_addon.config.conditioning_embedding_out_channels) - 1) + condition_downscale_factor = 2 ** (len(self.config.ctrl_conditioning_embedding_out_channels) - 1) vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1) compatible = condition_downscale_factor == vae_downscale_factor return compatible, condition_downscale_factor, vae_downscale_factor @@ -804,7 +837,7 @@ def forward( """ # check channel order - if self.control_addon.config.conditioning_channel_order == "bgr": + if self.config.ctrl_conditioning_channel_order == "bgr": controlnet_cond = torch.flip(controlnet_cond, dims=[1]) # prepare attention_mask @@ -837,9 +870,9 @@ def forward( t_emb = t_emb.to(dtype=sample.dtype) if self.config.ctrl_learn_time_embedding: - ctrl_temb = self.control_addon.time_embedding(t_emb, timestep_cond) + ctrl_temb = self.ctrl_time_embedding(t_emb, timestep_cond) base_temb = self.base_time_embedding(t_emb, timestep_cond) - interpolation_param = self.control_addon.config.time_embedding_mix**0.3 + interpolation_param = self.config.time_embedding_mix**0.3 temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) else: @@ -892,37 +925,35 @@ def forward( hs_base, hs_ctrl = [], [] # Cross Control - # Let's first define variables to shorten notation - - guided_hint = self.control_addon.controlnet_cond_embedding(controlnet_cond) + guided_hint = self.controlnet_cond_embedding(controlnet_cond) # 1 - conv in & down h_base = self.base_conv_in(h_base) - h_ctrl = self.control_addon.conv_in(h_ctrl) + h_ctrl = self.ctrl_conv_in(h_ctrl) if guided_hint is not None: h_ctrl += guided_hint - h_base = h_base + self.pre_zero_convs_c2b(h_ctrl) * conditioning_scale # add ctrl -> base # todo umer: define self.pre_zero_convs_c2b + h_base = h_base + self.control_to_base_for_conv_in(h_ctrl) * conditioning_scale # add ctrl -> base hs_base.append(h_base) hs_ctrl.append(h_ctrl) - for down in self.down_blocks: # todo umer: define self.down_blocks - h_base,h_ctrl,residual_hb,residual_hc = down(h_base,h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) + for down in self.down_blocks: + h_base,h_ctrl,residual_hb,residual_hc = down(h_base,h_ctrl, temb, cemb, conditioning_scale, cross_attention_kwargs, attention_mask) hs_base.extend(residual_hb) hs_ctrl.extend(residual_hc) # 2 - mid - h_base,h_ctrl = self.mid_block(h_base,h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # todo umer: define self.mid_block + h_base,h_ctrl = self.mid_block(h_base,h_ctrl, temb, cemb, conditioning_scale, cross_attention_kwargs, attention_mask) # 3 - up - for up in self.up_blocks: # todo umer: define self.up_blocks + for up in self.up_blocks: n_resnets = len(up.resnets) skips_hb = hs_base[-n_resnets:] skips_hc = hs_ctrl[-n_resnets:] hs_base = hs_base[:-n_resnets] hs_ctrl = hs_ctrl[:-n_resnets] - h_base = up(h_base,h_ctrl,skips_hb,skips_hc,temb, cemb, attention_mask, cross_attention_kwargs) + h_base = up(h_base,skips_hb,skips_hc,temb, cemb, conditioning_scale, cross_attention_kwargs, attention_mask) # 4 - conv out h_base = self.base_conv_norm_out(h_base) @@ -945,8 +976,9 @@ def __init__( temb_channels: int, max_norm_num_groups: Optional[int] = 32, has_crossattn=True, - transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple[int]]]] = 1, - num_attention_heads: Optional[int] = 1, + transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, + base_num_attention_heads: Optional[int] = 1, + ctrl_num_attention_heads: Optional[int] = 1, cross_attention_dim: Optional[int] = 1024, add_downsample: bool = True, upcast_attention: Optional[bool] = False, @@ -961,14 +993,12 @@ def __init__( num_layers = 2 # only support sd + sdxl - self.has_cross_attention = has_crossattn - self.num_attention_heads = num_attention_heads if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * num_layers for i in range(num_layers): base_in_channels = base_in_channels if i == 0 else base_out_channels - ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_in_channels + ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels # Before the resnet/attention application, information is concatted from base to control. # Concat doesn't require change in number of channels @@ -983,11 +1013,11 @@ def __init__( ) ctrl_resnets.append( ResnetBlock2D( - in_channels=ctrl_in_channels, - out_channels=ctrl_in_channels, + in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl + out_channels=ctrl_out_channels, temb_channels=temb_channels, - groups=find_largest_factor(ctrl_in_channels, max_factor=max_norm_num_groups), - groups_out=find_largest_factor(ctrl_in_channels, max_factor=max_norm_num_groups), + groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups), + groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), eps=1e-5, ) ) @@ -995,8 +1025,8 @@ def __init__( if has_crossattn: base_attentions.append( Transformer2DModel( - num_attention_heads, - base_out_channels // num_attention_heads, + base_num_attention_heads, + base_out_channels // base_num_attention_heads, in_channels=base_out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, @@ -1006,10 +1036,10 @@ def __init__( ) ctrl_attentions.append( Transformer2DModel( - num_attention_heads, - ctrl_out_channels // num_attention_heads, + ctrl_num_attention_heads, + ctrl_out_channels // ctrl_num_attention_heads, in_channels=ctrl_out_channels, - num_layers=transformer_layers_per_block, + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, use_linear_projection=True, upcast_attention=upcast_attention, @@ -1027,7 +1057,7 @@ def __init__( base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) self.base_downsamplers = Downsample2D(base_out_channels, use_conv=True, out_channels=base_out_channels, name="op") - self.ctrl_downsamplers = Downsample2D(ctrl_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op") + self.ctrl_downsamplers = Downsample2D(ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op") # After the downsampler application, information is added from control to base # Addition requires change in number of channels @@ -1081,9 +1111,9 @@ def forward( self, hidden_states_base: torch.FloatTensor, hidden_states_ctrl: torch.FloatTensor, - conditioning_scale: Optional[float] = 1.0, - temb: Optional[torch.FloatTensor] = None, + temb: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, + conditioning_scale: Optional[float] = 1.0, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, @@ -1163,8 +1193,10 @@ def __init__( base_channels: int, ctrl_channels: int, temb_channels: Optional[int] = None, + max_norm_num_groups: Optional[int] = 32, transformer_layers_per_block: int = 1, - num_attention_heads: Optional[int] = 1, + base_num_attention_heads: Optional[int] = 1, + ctrl_num_attention_heads: Optional[int] = 1, cross_attention_dim: Optional[int] = 1024, upcast_attention: bool = False, ): @@ -1179,17 +1211,20 @@ def __init__( in_channels=base_channels, temb_channels=temb_channels, cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, + num_attention_heads=base_num_attention_heads, use_linear_projection=True, upcast_attention=upcast_attention ) + self.ctrl_midblock = UNetMidBlock2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block, in_channels=ctrl_channels + base_channels, out_channels=ctrl_channels, temb_channels=temb_channels, + # number or norm groups must divide both in_channels and out_channels + resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups), cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, # todo umer: n_attn_heads different for base / ctrl? + num_attention_heads=ctrl_num_attention_heads, use_linear_projection=True, upcast_attention=upcast_attention ) @@ -1221,9 +1256,9 @@ def forward( self, hidden_states_base: torch.FloatTensor, hidden_states_ctrl: torch.FloatTensor, + temb: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, conditioning_scale: Optional[float] = 1.0, - temb: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, @@ -1280,7 +1315,7 @@ def __init__( transformer_layers_per_block = [transformer_layers_per_block] * num_layers for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + res_skip_channels = in_channels if (i < num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels)) @@ -1337,14 +1372,14 @@ def from_modules( def forward( self, hidden_states: torch.FloatTensor, - res_hidden_states_tuple_base: Tuple[torch.FloatTensor, ...], - res_hidden_states_tuple_cltr: Tuple[torch.FloatTensor, ...], - conditioning_scale: Optional[float] = 1.0, - temb: Optional[torch.FloatTensor] = None, + res_hidden_states_tuple_base: Tuple[torch.FloatTensor, ...], # todo umer: why ... in type hint? + res_hidden_states_tuple_cltr: Tuple[torch.FloatTensor, ...], # todo umer: why ... in type hint? + temb: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, + conditioning_scale: Optional[float] = 1.0, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: # todo umer: output type hint correct? if cross_attention_kwargs is not None: @@ -1380,7 +1415,7 @@ def forward( return_dict=False, )[0] - if self.upsampler is not None: + if self.upsamplers is not None: c2b = self.ctrl_to_base[-1] res_h_base = res_hidden_states_tuple_base[0] res_h_ctrl = res_hidden_states_tuple_cltr[0] @@ -1398,7 +1433,7 @@ def forward( encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = self.upsampler(hidden_states, upsample_size) + hidden_states = self.upsamplers(hidden_states, upsample_size) return hidden_states From 395716ad9247df7ab199f85baa621e9906c99ab7 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Sun, 24 Mar 2024 21:30:47 +0100 Subject: [PATCH 57/75] Created init for UNetCnxs and CnxsAddon --- src/diffusers/models/controlnet_xs.py | 394 ++++++++++++++++---------- 1 file changed, 240 insertions(+), 154 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index a885b5576c6a..d2d03dd84d66 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -195,6 +195,144 @@ def gather_base_subblock_sizes(blocks_sizes: List[int]): "up - in": up_in, } + @staticmethod + def get_down_block( + base_in_channels: int, + base_out_channels: int, + ctrl_in_channels: int, + ctrl_out_channels: int, + temb_channels: int, + max_norm_num_groups: Optional[int] = 32, + has_crossattn=True, + transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, + num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + add_downsample: bool = True, + upcast_attention: Optional[bool] = False, + ): + num_layers = 2 # only support sd + sdxl + + resnets = [] + attentions = [] + ctrl_to_base = [] + base_to_ctrl = [] + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels + + # Before the resnet/attention application, information is concatted from base to control. + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels)) + + resnets.append( + ResnetBlock2D( + in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl + out_channels=ctrl_out_channels, + temb_channels=temb_channels, + groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups), + groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), + eps=1e-5, + ) + ) + + if has_crossattn: + attentions.append( + Transformer2DModel( + num_attention_heads, + ctrl_out_channels // num_attention_heads, + in_channels=ctrl_out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=True, + upcast_attention=upcast_attention, + norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), + ) + ) + + # After the resnet/attention application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) + + if add_downsample: + # Before the downsampler application, information is concatted from base to control + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) + + downsamplers = Downsample2D(ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op") + + # After the downsampler application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) + else: + downsamplers = None + + module_dict = nn.ModuleDict({ + "resnets": nn.ModuleList(resnets), + "base_to_ctrl": nn.ModuleList(base_to_ctrl), + "ctrl_to_base": nn.ModuleList(ctrl_to_base), + }) + if has_crossattn: + module_dict["attentions"] = nn.ModuleList(attentions) + if downsamplers is not None: + module_dict["downsamplers"] = downsamplers + + return module_dict + + @staticmethod + def get_mid_block( + base_channels: int, + ctrl_channels: int, + temb_channels: Optional[int] = None, + max_norm_num_groups: Optional[int] = 32, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + upcast_attention: bool = False, + ): + # Before the midblock application, information is concatted from base to control. + # Concat doesn't require change in number of channels + base_to_ctrl = make_zero_conv(base_channels, base_channels) + + midblock = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=ctrl_channels + base_channels, + out_channels=ctrl_channels, + temb_channels=temb_channels, + # number or norm groups must divide both in_channels and out_channels + resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups), + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + use_linear_projection=True, + upcast_attention=upcast_attention + ) + + # After the midblock application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base = make_zero_conv(ctrl_channels, base_channels) + + return nn.ModuleDict({ + "base_to_ctrl": base_to_ctrl, + "midblock": midblock, + "ctrl_to_base": ctrl_to_base + }) + + @staticmethod + def get_up_connections( + out_channels: int, + prev_output_channel: int, + ctrl_skip_channels: List[int], + ): + ctrl_to_base = [] + num_layers = 3 # only support sd + sdxl + for i in range(num_layers): + resnet_in_channels = prev_output_channel if i == 0 else out_channels + ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels)) + + return nn.ModuleList(ctrl_to_base) + @classmethod def from_unet( cls, @@ -272,13 +410,9 @@ def __init__( time_embedding_dim: Optional[int] = 1280, time_embedding_mix: float = 1.0, learn_time_embedding: bool = False, - channels_base: Dict[str, List[Tuple[int]]] = { - "down - out": [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280], - "mid - out": 1280, - "up - in": [1280, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320], - }, attention_head_dim: Union[int, Tuple[int]] = 4, block_out_channels: Tuple[int] = (4, 8, 16, 16), + base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280), cross_attention_dim: int = 1024, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", @@ -311,20 +445,27 @@ def __init__( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) - if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + transformer_layers_per_block = repeat_if_not_list(transformer_layers_per_block, repetitions=len(down_block_types)) + cross_attention_dim = repeat_if_not_list(cross_attention_dim, repetitions=len(down_block_types)) + num_attention_heads = repeat_if_not_list(num_attention_heads, repetitions=len(down_block_types)) # todo umer: im using # attn heads & dim attn heads. should only be one. + attention_head_dim = repeat_if_not_list(attention_head_dim, repetitions=len(down_block_types)) + + if len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + if len(attention_head_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." ) - elif isinstance(attention_head_dim, int): - attention_head_dim = [attention_head_dim] * len(down_block_types) - # input - self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1) + # 5 - Create conditioning hint embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) # time if learn_time_embedding: @@ -334,116 +475,70 @@ def __init__( self.time_embed_act = None - self.down_subblocks = nn.ModuleList([]) - self.up_subblocks = nn.ModuleList([]) + self.down_blocks = nn.ModuleList([]) + self.up_connections = nn.ModuleList([]) - if isinstance(num_attention_heads, int): - num_attention_heads = (num_attention_heads,) * len(down_block_types) - - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + # input + self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1) + self.control_to_base_for_conv_in = make_zero_conv(block_out_channels[0], base_block_out_channels[0]) # down - output_channel = block_out_channels[0] - subblock_counter = 0 - + base_out_channels = base_block_out_channels[0] + ctrl_out_channels = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - use_crossattention = down_block_type == "CrossAttnDownBlock2D" - - self.down_subblocks.append( - CrossAttnDownSubBlock2D( - has_crossattn=use_crossattention, - in_channels=input_channel + channels_base["down - out"][subblock_counter], - out_channels=output_channel, - temb_channels=time_embedding_dim, - transformer_layers_per_block=transformer_layers_per_block[i], - num_attention_heads=num_attention_heads[i], - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - max_norm_num_groups=max_norm_num_groups, - ) - ) - subblock_counter += 1 - self.down_subblocks.append( - CrossAttnDownSubBlock2D( - has_crossattn=use_crossattention, - in_channels=output_channel + channels_base["down - out"][subblock_counter], - out_channels=output_channel, - temb_channels=time_embedding_dim, - transformer_layers_per_block=transformer_layers_per_block[i], - num_attention_heads=num_attention_heads[i], - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - max_norm_num_groups=max_norm_num_groups, - ) - ) - subblock_counter += 1 - if i < len(down_block_types) - 1: - self.down_subblocks.append( - DownSubBlock2D( - in_channels=output_channel + channels_base["down - out"][subblock_counter], - out_channels=output_channel, - ) - ) - subblock_counter += 1 + base_in_channels = base_out_channels + base_out_channels = base_block_out_channels[i] + ctrl_in_channels = ctrl_out_channels + ctrl_out_channels = block_out_channels[i] + has_crossattn = "CrossAttn" in down_block_type + is_final_block = i==len(down_block_types)-1 + + self.down_blocks.append(ControlNetXSAddon.get_down_block( + base_in_channels = base_in_channels, + base_out_channels = base_out_channels, + ctrl_in_channels = ctrl_in_channels, + ctrl_out_channels = ctrl_out_channels, + temb_channels = time_embedding_dim, + max_norm_num_groups = max_norm_num_groups, + has_crossattn = has_crossattn, + transformer_layers_per_block = transformer_layers_per_block[i], + num_attention_heads = attention_head_dim[i], + cross_attention_dim = cross_attention_dim[i], + add_downsample = not is_final_block, + upcast_attention = upcast_attention + )) # mid - mid_in_channels = block_out_channels[-1] + channels_base["down - out"][subblock_counter] - mid_out_channels = block_out_channels[-1] - - self.mid_block = UNetMidBlock2DCrossAttn( - transformer_layers_per_block=transformer_layers_per_block[-1], - in_channels=mid_in_channels, - out_channels=mid_out_channels, - temb_channels=time_embedding_dim, - resnet_eps=1e-05, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads[-1], - resnet_groups=find_largest_factor(mid_in_channels, max_norm_num_groups), - resnet_groups_out=find_largest_factor(mid_out_channels, max_norm_num_groups), - use_linear_projection=True, - upcast_attention=upcast_attention, + self.mid_block = ControlNetXSAddon.get_mid_block( + base_channels=base_block_out_channels[-1], + ctrl_channels=block_out_channels[-1], + temb_channels = time_embedding_dim, + transformer_layers_per_block = transformer_layers_per_block[-1], + num_attention_heads = attention_head_dim[-1], + cross_attention_dim = cross_attention_dim[-1], + upcast_attention = upcast_attention, ) - # 3 - Gather Channel Sizes - channels_ctrl = { - "down - out": [self.conv_in.out_channels] + [s.out_channels for s in self.down_subblocks], - "mid - out": self.down_subblocks[-1].out_channels, - } - - # 4 - Build connections between base and control model - # b2c = base -> ctrl ; c2b = ctrl -> base - self.down_zero_convs_b2c = nn.ModuleList([]) - self.down_zero_convs_c2b = nn.ModuleList([]) - self.mid_zero_convs_c2b = nn.ModuleList([]) - self.up_zero_convs_c2b = nn.ModuleList([]) - - # 4.1 - Connections from base encoder to ctrl encoder - # As the information is concatted to ctrl, the channels sizes don't change. - for c in channels_base["down - out"]: - self.down_zero_convs_b2c.append(make_zero_conv(c, c)) - - # 4.2 - Connections from ctrl encoder to base encoder - # As the information is added to base, the out-channels need to match base. - for ch_base, ch_ctrl in zip(channels_base["down - out"], channels_ctrl["down - out"]): - self.down_zero_convs_c2b.append(make_zero_conv(ch_ctrl, ch_base)) + # up + # The skip connection channels are the output of the conv_in and of all the down subblocks + ctrl_skip_channels = [block_out_channels[0]] + for i, out_channels in enumerate(block_out_channels): + number_of_subblocks = 3 if i [({"in":0,"out":0}, {"in":0,"out":1}, {"in":1,"out":2}, {"in":2,"out":3}]""" - left_shifted_iterable = [iterable[0]] + list(iterable[:-1]) - return [ - {keys[0]: o1, keys[1]: o2} - for o1,o2 in zip(left_shifted_iterable, iterable) - ] - - down_block_channels = {"base": left_shifted_iterator_pairs(block_out_channels), "ctrl": left_shifted_iterator_pairs(ctrl_block_out_channels)} - - for i, (down_block_type, b_channels, c_channels) in enumerate(zip(down_block_types, down_block_channels["base"], down_block_channels["ctrl"])): + base_out_channels = block_out_channels[0] + ctrl_out_channels = ctrl_block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + base_in_channels = base_out_channels + base_out_channels = block_out_channels[i] + ctrl_in_channels = ctrl_out_channels + ctrl_out_channels = ctrl_block_out_channels[i] has_crossattn = "CrossAttn" in down_block_type - add_downsample = i (3,2,1) - # into -> [{"in": 3, "out": 3}, {"in": 3, "out": 2}, {"in": 2, "out": 1}] - rev_down_block_channels = left_shifted_iterator_pairs(list(reversed(block_out_channels))) + reversed_block_out_channels = list(reversed(block_out_channels)) + out_channels = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): - has_crossattn = "CrossAttn" in up_block_type - add_upsample = i0 else in_channels + prev_output_channel = out_channels + out_channels = reversed_block_out_channels[i] + in_channels = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)] + has_crossattn = "CrossAttn" in up_block_type + is_final_block = i == len(block_out_channels) - 1 + up_blocks.append(ControlNetXSCrossAttnUpBlock2D( in_channels = in_channels, out_channels = out_channels, @@ -663,17 +749,13 @@ def left_shifted_iterator_pairs(iterable, keys=["in", "out"]): transformer_layers_per_block = rev_transformer_layers_per_block[-1], num_attention_heads = rev_num_attention_heads[-1], cross_attention_dim = rev_cross_attention_dim[-1], - add_upsample = add_upsample, + add_upsample = not is_final_block, upcast_attention = upcast_attention, )) self.down_blocks = nn.ModuleList(down_blocks) self.up_blocks = nn.ModuleList(up_blocks) - # todo umer: create control_addon.conv_in - # todo umer: create ctrl_time_embedding - # tood umer: creatae b2c for conv->down0 - @classmethod def from_unet2d( cls, @@ -1315,7 +1397,7 @@ def __init__( transformer_layers_per_block = [transformer_layers_per_block] * num_layers for i in range(num_layers): - res_skip_channels = in_channels if (i < num_layers - 1) else out_channels + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels)) @@ -1457,3 +1539,7 @@ def find_largest_factor(number, max_factor): if residual == 0: return factor factor -= 1 + + +def repeat_if_not_list(value, repetitions): + return value if isinstance(value, (tuple, list)) else [value] * repetitions From 987e4c9749477e5d5711a5712bdd72ec994cc790 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Sun, 24 Mar 2024 23:30:43 +0100 Subject: [PATCH 58/75] CheckIn --- src/diffusers/models/controlnet_xs.py | 199 +++++++++++++++++++------- 1 file changed, 149 insertions(+), 50 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index d2d03dd84d66..dcbedd663d2f 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -27,7 +27,7 @@ TimestepEmbedding, ) from .modeling_utils import ModelMixin -from .unets.unet_2d_blocks import Downsample2D, ResnetBlock2D, Transformer2DModel, UNetMidBlock2DCrossAttn, Upsample2D +from .unets.unet_2d_blocks import Downsample2D, ResnetBlock2D, Transformer2DModel, UNetMidBlock2DCrossAttn, Upsample2D, CrossAttnDownBlock2D, CrossAttnUpBlock2D from .unets.unet_2d_condition import UNet2DConditionModel @@ -221,6 +221,7 @@ def get_down_block( transformer_layers_per_block = [transformer_layers_per_block] * num_layers for i in range(num_layers): + base_in_channels = base_in_channels if i == 0 else base_out_channels ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels # Before the resnet/attention application, information is concatted from base to control. @@ -1160,34 +1161,65 @@ def __init__( @classmethod def from_modules( cls, - base_resnets: List[ResnetBlock2D], ctrl_resnets: List[ResnetBlock2D], - base_to_control_connections: List[nn.Conv2d], control_to_base_connections: List[nn.Conv2d], - base_attentions: Optional[List[Transformer2DModel]] = None, ctrl_attentions: Optional[List[Transformer2DModel]] = None, - base_downsampler: Optional[List[Transformer2DModel]] = None, ctrl_downsampler: Optional[List[Transformer2DModel]] = None,): - """todo umer""" - block = cls( - in_channels = None, - out_channels = None, - temb_channels = None, - max_norm_num_groups = 32, - has_crossattn = True, - transformer_layers_per_block = 1, - num_attention_heads = 1, - cross_attention_dim = 1024, - add_downsample = True, - upcast_attention = False, + base_downblock: CrossAttnDownBlock2D, + ctrl_downblock: nn.ModuleDict + ): + # get params + def get_first_cross_attention(block): + return block.attentions[0].transformer_blocks[0].attn2 + + base_in_channels = base_downblock.resnets[0].in_channels + base_out_channels = base_downblock.resnets[0].out_channels + ctrl_in_channels = ctrl_downblock["resnets"][0].in_channels - base_in_channels # base channels are concatted to ctrl channels in init + ctrl_out_channels = ctrl_downblock["resnets"][0].out_channels + temb_channels = base_downblock.resnets[0].time_emb_proj.in_features + num_groups = ctrl_downblock["resnets"][0].norm1.num_groups + if hasattr(base_downblock, "attentions"): + has_crossattn = True + transformer_layers_per_block = len(base_downblock.attentions) + base_num_attention_heads = get_first_cross_attention(base_downblock).heads + ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads + cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim + upcast_attention = get_first_cross_attention(base_downblock).upcast_attention + else: + has_crossattn = False + transformer_layers_per_block = None + base_num_attention_heads = None + ctrl_num_attention_heads = None + cross_attention_dim = None + upcast_attention = None + add_downsample = base_downblock.downsamplers is not None + + # create model + model = cls( + base_in_channels = base_in_channels, + base_out_channels = base_out_channels, + ctrl_in_channels = ctrl_in_channels, + ctrl_out_channels = ctrl_out_channels, + temb_channels = temb_channels, + max_norm_num_groups = num_groups, + has_crossattn=has_crossattn, + transformer_layers_per_block = transformer_layers_per_block, + base_num_attention_heads = base_num_attention_heads, + ctrl_num_attention_heads = ctrl_num_attention_heads, + cross_attention_dim = cross_attention_dim, + add_downsample = add_downsample, + upcast_attention = upcast_attention, ) - block.base_resnets = base_resnets - block.base_attentions = base_attentions - block.ctrl_resnets = ctrl_resnets - block.ctrl_attentions = ctrl_attentions - block.b2c = base_to_control_connections - block.c2b = control_to_base_connections - block.base_downsampler = base_downsampler - block.ctrl_downsampler = ctrl_downsampler + # # load weights + model.base_resnets.load_state_dict(base_downblock.resnets.state_dict()) + model.ctrl_resnets.load_state_dict(ctrl_downblock["resnets"].state_dict()) + if has_crossattn: + model.base_attentions.load_state_dict(base_downblock.attentions.state_dict()) + model.ctrl_attentions.load_state_dict(ctrl_downblock["attentions"].state_dict()) + if add_downsample: + model.base_downsamplers.load_state_dict(base_downblock.downsamplers[0].state_dict()) + model.ctrl_downsamplers.load_state_dict(ctrl_downblock["downsamplers"].state_dict()) + model.base_to_ctrl.load_state_dict(ctrl_downblock["base_to_ctrl"].state_dict()) + model.ctrl_to_base.load_state_dict(ctrl_downblock["ctrl_to_base"].state_dict()) - return block + return model def forward( self, @@ -1320,19 +1352,46 @@ def __init__( @classmethod def from_modules( cls, - resnet: ResnetBlock2D, - attention: Optional[Transformer2DModel] = None, - upsampler: Optional[Upsample2D] = None, + base_to_ctrl: nn.Conv2d, + base_midblock: UNetMidBlock2DCrossAttn, + ctrl_midblock: UNetMidBlock2DCrossAttn, + ctrl_to_base: nn.Conv2d + ): - """Create empty subblock and set resnet, attention and upsampler manually""" - # todo umer - subblock = cls() - subblock.resnet = resnet - subblock.attention = attention - subblock.upsampler = upsampler - subblock.in_channels = resnet.in_channels - subblock.out_channels = resnet.out_channels - return subblock + # get params + def get_first_cross_attention(midblock): + return midblock.attentions[0].transformer_blocks[0].attn2 + + base_channels = ctrl_to_base.out_channels + ctrl_channels = ctrl_to_base.in_channels + transformer_layers_per_block = len(base_midblock.attentions) + temb_channels = base_midblock.resnets[0].time_emb_proj.in_features + num_groups = ctrl_midblock.resnets[0].norm1.num_groups + base_num_attention_heads = get_first_cross_attention(base_midblock).heads + ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads + cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim + upcast_attention = get_first_cross_attention(base_midblock).upcast_attention + + # create model + model = cls( + base_channels=base_channels, + ctrl_channels=ctrl_channels, + temb_channels=temb_channels, + max_norm_num_groups = num_groups, + transformer_layers_per_block = transformer_layers_per_block, + base_num_attention_heads = base_num_attention_heads, + ctrl_num_attention_heads = ctrl_num_attention_heads, + cross_attention_dim = cross_attention_dim, + upcast_attention = upcast_attention, + ) + + # load weights + model.base_to_ctrl.load_state_dict(base_to_ctrl.state_dict()) + model.base_midblock.load_state_dict(base_midblock.state_dict()) + model.ctrl_midblock.load_state_dict(ctrl_midblock.state_dict()) + model.ctrl_to_base.load_state_dict(ctrl_to_base.state_dict()) + + return model def forward( self, @@ -1437,19 +1496,59 @@ def __init__( @classmethod def from_modules( cls, - resnet: ResnetBlock2D, - attention: Optional[Transformer2DModel] = None, - upsampler: Optional[Upsample2D] = None, + base_upblock: CrossAttnUpBlock2D, + ctrl_to_base_skip_connections: nn.ModuleList ): - """Create empty subblock and set resnet, attention and upsampler manually""" - # todo umer - subblock = cls() - subblock.resnet = resnet - subblock.attention = attention - subblock.upsampler = upsampler - subblock.in_channels = resnet.in_channels - subblock.out_channels = resnet.out_channels - return subblock + # get params + def get_first_cross_attention(block): + return block.attentions[0].transformer_blocks[0].attn2 + + base_in_channels = base_upblock.resnets[0].in_channels + base_out_channels = base_upblock.resnets[0].out_channels + ctrl_in_channels = ctrl_downblock["resnets"][0].in_channels - base_in_channels # base channels are concatted to ctrl channels in init + ctrl_out_channels = ctrl_downblock["resnets"][0].out_channels + temb_channels = base_upblock.resnets[0].time_emb_proj.in_features + num_groups = base_upblock["resnets"][0].norm1.num_groups + if hasattr(base_upblock, "attentions"): + has_crossattn = True + transformer_layers_per_block = len(base_upblock.attentions) + num_attention_heads = get_first_cross_attention(base_upblock).heads + cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim + upcast_attention = get_first_cross_attention(base_upblock).upcast_attention + else: + has_crossattn = False + transformer_layers_per_block = None + num_attention_heads = None + cross_attention_dim = None + upcast_attention = None + add_upsample = base_upblock.upsamplers is not None + + # create model + model = cls( + # todo umer + # in_channels: int, + # out_channels: int, + # prev_output_channel: int, + # ctrl_skip_channels: List[int], + temb_channels = temb_channels, + has_crossattn=True, + transformer_layers_per_block = transformer_layers_per_block, + num_attention_heads = num_attention_heads, + cross_attention_dim = cross_attention_dim, + add_upsample = add_upsample, + upcast_attention = upcast_attention + ) + + # load weights + model.resnets.load_state_dict(base_upblock.resnets.state_dict()) + if has_crossattn: + model.attentions.load_state_dict(base_upblock.attentions.state_dict()) + if add_upsample: + model.upsamplers.load_state_dict(base_upblock.upsamplers[0].state_dict()) + model.ctrl_to_base.load_state_dict(ctrl_to_base_skip_connections.state_dict()) + + return model + def forward( self, From 890542b057d5562a65935ae2be00b31ac05e7fa9 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Mon, 25 Mar 2024 17:19:20 +0100 Subject: [PATCH 59/75] Made from_modules, from_unet and no_control work --- src/diffusers/models/controlnet_xs.py | 653 +++++++++++++------------- 1 file changed, 339 insertions(+), 314 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index dcbedd663d2f..999d567ee1b4 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -27,7 +27,15 @@ TimestepEmbedding, ) from .modeling_utils import ModelMixin -from .unets.unet_2d_blocks import Downsample2D, ResnetBlock2D, Transformer2DModel, UNetMidBlock2DCrossAttn, Upsample2D, CrossAttnDownBlock2D, CrossAttnUpBlock2D +from .unets.unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + Downsample2D, + ResnetBlock2D, + Transformer2DModel, + UNetMidBlock2DCrossAttn, + Upsample2D, +) from .unets.unet_2d_condition import UNet2DConditionModel @@ -143,58 +151,6 @@ class ControlNetXSAddon(ModelMixin, ConfigMixin): Maximum number of groups in group normal. The actual number will the the largest divisor of the respective channels, that is <= max_norm_num_groups. """ - @staticmethod - def gather_base_subblock_sizes(blocks_sizes: List[int]): - """ - To create a correctly sized `ControlNetXSAddon`, we need to know - the channels sizes of each base subblock. - - Parameters: - blocks_sizes (`List[int]`): - Channel sizes of each base block. - """ - - n_blocks = len(blocks_sizes) - n_subblocks_per_block = 3 - - down_out = [] - up_in = [] - - # down_out - for b in range(n_blocks): - for i in range(n_subblocks_per_block): - if b == n_blocks - 1 and i == 2: - # Last block has no downsampler, so there are only 2 subblocks instead of 3 - continue - - # The input channels are changed by the first resnet, which is in the first subblock. - if i == 0: - # Same input channels - down_out.append(blocks_sizes[max(b - 1, 0)]) - else: - # Changed input channels - down_out.append(blocks_sizes[b]) - - down_out.append(blocks_sizes[-1]) - - # up_in - rev_blocks_sizes = list(reversed(blocks_sizes)) - for b in range(len(rev_blocks_sizes)): - for i in range(n_subblocks_per_block): - # The input channels are changed by the first resnet, which is in the first subblock. - if i == 0: - # Same input channels - up_in.append(rev_blocks_sizes[max(b - 1, 0)]) - else: - # Changed input channels - up_in.append(rev_blocks_sizes[b]) - - return { - "down - out": down_out, - "mid - out": blocks_sizes[-1], - "up - in": up_in, - } - @staticmethod def get_down_block( base_in_channels: int, @@ -210,7 +166,7 @@ def get_down_block( add_downsample: bool = True, upcast_attention: Optional[bool] = False, ): - num_layers = 2 # only support sd + sdxl + num_layers = 2 # only support sd + sdxl resnets = [] attentions = [] @@ -230,7 +186,7 @@ def get_down_block( resnets.append( ResnetBlock2D( - in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl + in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl out_channels=ctrl_out_channels, temb_channels=temb_channels, groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups), @@ -262,7 +218,9 @@ def get_down_block( # Concat doesn't require change in number of channels base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) - downsamplers = Downsample2D(ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op") + downsamplers = Downsample2D( + ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op" + ) # After the downsampler application, information is added from control to base # Addition requires change in number of channels @@ -270,11 +228,13 @@ def get_down_block( else: downsamplers = None - module_dict = nn.ModuleDict({ - "resnets": nn.ModuleList(resnets), - "base_to_ctrl": nn.ModuleList(base_to_ctrl), - "ctrl_to_base": nn.ModuleList(ctrl_to_base), - }) + module_dict = nn.ModuleDict( + { + "resnets": nn.ModuleList(resnets), + "base_to_ctrl": nn.ModuleList(base_to_ctrl), + "ctrl_to_base": nn.ModuleList(ctrl_to_base), + } + ) if has_crossattn: module_dict["attentions"] = nn.ModuleList(attentions) if downsamplers is not None: @@ -292,7 +252,7 @@ def get_mid_block( num_attention_heads: Optional[int] = 1, cross_attention_dim: Optional[int] = 1024, upcast_attention: bool = False, - ): + ): # Before the midblock application, information is concatted from base to control. # Concat doesn't require change in number of channels base_to_ctrl = make_zero_conv(base_channels, base_channels) @@ -307,27 +267,23 @@ def get_mid_block( cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, use_linear_projection=True, - upcast_attention=upcast_attention + upcast_attention=upcast_attention, ) # After the midblock application, information is added from control to base # Addition requires change in number of channels ctrl_to_base = make_zero_conv(ctrl_channels, base_channels) - return nn.ModuleDict({ - "base_to_ctrl": base_to_ctrl, - "midblock": midblock, - "ctrl_to_base": ctrl_to_base - }) + return nn.ModuleDict({"base_to_ctrl": base_to_ctrl, "midblock": midblock, "ctrl_to_base": ctrl_to_base}) @staticmethod def get_up_connections( out_channels: int, prev_output_channel: int, ctrl_skip_channels: List[int], - ): + ): ctrl_to_base = [] - num_layers = 3 # only support sd + sdxl + num_layers = 3 # only support sd + sdxl for i in range(num_layers): resnet_in_channels = prev_output_channel if i == 0 else out_channels ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels)) @@ -337,15 +293,18 @@ def get_up_connections( @classmethod def from_unet( cls, - base_model: UNet2DConditionModel, + unet: UNet2DConditionModel, size_ratio: Optional[float] = None, block_out_channels: Optional[List[int]] = None, num_attention_heads: Optional[List[int]] = None, learn_time_embedding: bool = False, time_embedding_mix: int = 1.0, + conditioning_channels: int = 3, + conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), ): r""" + todo umer Instantiate a [`ControlNetXSAddon`] from a [`UNet2DConditionModel`]. Parameters: @@ -372,34 +331,35 @@ def from_unet( "Pass exactly one of `block_out_channels` (for absolute sizing) or `size_ratio` (for relative sizing)." ) - channels_base = ControlNetXSAddon.gather_base_subblock_sizes(base_model.config.block_out_channels) - - block_out_channels = [int(b * size_ratio) for b in base_model.config.block_out_channels] + # Create model + block_out_channels = block_out_channels or [int(b * size_ratio) for b in unet.config.block_out_channels] if num_attention_heads is None: # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. - num_attention_heads = base_model.config.attention_head_dim + num_attention_heads = unet.config.attention_head_dim - max_norm_num_groups = base_model.config.norm_num_groups + model = cls( + conditioning_channels = conditioning_channels, + conditioning_channel_order = conditioning_channel_order, + conditioning_embedding_out_channels = conditioning_embedding_out_channels, + time_embedding_input_dim = unet.time_embedding.linear_1.in_features, + time_embedding_dim = unet.time_embedding.linear_1.out_features, + time_embedding_mix = time_embedding_mix, + learn_time_embedding = learn_time_embedding, + attention_head_dim = num_attention_heads, + block_out_channels = block_out_channels, + base_block_out_channels = unet.config.block_out_channels, + cross_attention_dim = unet.config.cross_attention_dim, + down_block_types = unet.config.down_block_types, + sample_size = unet.config.sample_size, + transformer_layers_per_block = unet.config.transformer_layers_per_block, + upcast_attention = unet.config.upcast_attention, + max_norm_num_groups = unet.config.norm_num_groups, + ) - time_embedding_input_dim = base_model.time_embedding.linear_1.in_features - time_embedding_dim = base_model.time_embedding.linear_1.out_features + # ensure that the ControlNetXSAddon is the same dtype as the UNet2DConditionModel + model.to(unet.dtype) - return ControlNetXSAddon( - learn_time_embedding=learn_time_embedding, - channels_base=channels_base, - attention_head_dim=num_attention_heads, - block_out_channels=block_out_channels, - cross_attention_dim=base_model.config.cross_attention_dim, - down_block_types=base_model.config.down_block_types, - sample_size=base_model.config.sample_size, - transformer_layers_per_block=base_model.config.transformer_layers_per_block, - upcast_attention=base_model.config.upcast_attention, - max_norm_num_groups=max_norm_num_groups, - conditioning_embedding_out_channels=conditioning_embedding_out_channels, - time_embedding_input_dim=time_embedding_input_dim, - time_embedding_dim=time_embedding_dim, - time_embedding_mix=time_embedding_mix, - ) + return model @register_to_config def __init__( @@ -446,9 +406,13 @@ def __init__( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) - transformer_layers_per_block = repeat_if_not_list(transformer_layers_per_block, repetitions=len(down_block_types)) + transformer_layers_per_block = repeat_if_not_list( + transformer_layers_per_block, repetitions=len(down_block_types) + ) cross_attention_dim = repeat_if_not_list(cross_attention_dim, repetitions=len(down_block_types)) - num_attention_heads = repeat_if_not_list(num_attention_heads, repetitions=len(down_block_types)) # todo umer: im using # attn heads & dim attn heads. should only be one. + num_attention_heads = repeat_if_not_list( + num_attention_heads, repetitions=len(down_block_types) + ) # todo umer: im using # attn heads & dim attn heads. should only be one. attention_head_dim = repeat_if_not_list(attention_head_dim, repetitions=len(down_block_types)) if len(num_attention_heads) != len(down_block_types): @@ -492,39 +456,43 @@ def __init__( ctrl_in_channels = ctrl_out_channels ctrl_out_channels = block_out_channels[i] has_crossattn = "CrossAttn" in down_block_type - is_final_block = i==len(down_block_types)-1 - - self.down_blocks.append(ControlNetXSAddon.get_down_block( - base_in_channels = base_in_channels, - base_out_channels = base_out_channels, - ctrl_in_channels = ctrl_in_channels, - ctrl_out_channels = ctrl_out_channels, - temb_channels = time_embedding_dim, - max_norm_num_groups = max_norm_num_groups, - has_crossattn = has_crossattn, - transformer_layers_per_block = transformer_layers_per_block[i], - num_attention_heads = attention_head_dim[i], - cross_attention_dim = cross_attention_dim[i], - add_downsample = not is_final_block, - upcast_attention = upcast_attention - )) + is_final_block = i == len(down_block_types) - 1 + + self.down_blocks.append( + ControlNetXSAddon.get_down_block( + base_in_channels=base_in_channels, + base_out_channels=base_out_channels, + ctrl_in_channels=ctrl_in_channels, + ctrl_out_channels=ctrl_out_channels, + temb_channels=time_embedding_dim, + max_norm_num_groups=max_norm_num_groups, + has_crossattn=has_crossattn, + transformer_layers_per_block=transformer_layers_per_block[i], + num_attention_heads=attention_head_dim[i], + cross_attention_dim=cross_attention_dim[i], + add_downsample=not is_final_block, + upcast_attention=upcast_attention, + ) + ) # mid self.mid_block = ControlNetXSAddon.get_mid_block( base_channels=base_block_out_channels[-1], ctrl_channels=block_out_channels[-1], - temb_channels = time_embedding_dim, - transformer_layers_per_block = transformer_layers_per_block[-1], - num_attention_heads = attention_head_dim[-1], - cross_attention_dim = cross_attention_dim[-1], - upcast_attention = upcast_attention, + temb_channels=time_embedding_dim, + transformer_layers_per_block=transformer_layers_per_block[-1], + num_attention_heads=attention_head_dim[-1], + cross_attention_dim=cross_attention_dim[-1], + upcast_attention=upcast_attention, ) # up # The skip connection channels are the output of the conv_in and of all the down subblocks ctrl_skip_channels = [block_out_channels[0]] for i, out_channels in enumerate(block_out_channels): - number_of_subblocks = 3 if i base + if do_control: + h_base = h_base + self.control_to_base_for_conv_in(h_ctrl) * conditioning_scale # add ctrl -> base hs_base.append(h_base) hs_ctrl.append(h_ctrl) for down in self.down_blocks: - h_base,h_ctrl,residual_hb,residual_hc = down(h_base,h_ctrl, temb, cemb, conditioning_scale, cross_attention_kwargs, attention_mask) + h_base, h_ctrl, residual_hb, residual_hc = down( + hidden_states_base=h_base, + hidden_states_ctrl=h_ctrl, + temb=temb, + encoder_hidden_states=cemb, + conditioning_scale=conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + do_control=do_control + ) hs_base.extend(residual_hb) hs_ctrl.extend(residual_hc) # 2 - mid - h_base,h_ctrl = self.mid_block(h_base,h_ctrl, temb, cemb, conditioning_scale, cross_attention_kwargs, attention_mask) + h_base, h_ctrl = self.mid_block( + hidden_states_base=h_base, + hidden_states_ctrl=h_ctrl, + temb=temb, + encoder_hidden_states=cemb, + conditioning_scale=conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + do_control=do_control + ) # 3 - up for up in self.up_blocks: @@ -1036,7 +1035,17 @@ def forward( skips_hc = hs_ctrl[-n_resnets:] hs_base = hs_base[:-n_resnets] hs_ctrl = hs_ctrl[:-n_resnets] - h_base = up(h_base,skips_hb,skips_hc,temb, cemb, conditioning_scale, cross_attention_kwargs, attention_mask) + h_base = up( + hidden_states=h_base, + res_hidden_states_tuple_base= skips_hb, + res_hidden_states_tuple_ctrl=skips_hc, + temb= temb, + encoder_hidden_states=cemb, + conditioning_scale= conditioning_scale, + cross_attention_kwargs= cross_attention_kwargs, + attention_mask= attention_mask, + do_control=do_control + ) # 4 - conv out h_base = self.base_conv_norm_out(h_base) @@ -1069,12 +1078,12 @@ def __init__( super().__init__() base_resnets = [] base_attentions = [] - ctrl_resnets =[] + ctrl_resnets = [] ctrl_attentions = [] ctrl_to_base = [] base_to_ctrl = [] - num_layers = 2 # only support sd + sdxl + num_layers = 2 # only support sd + sdxl if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * num_layers @@ -1096,7 +1105,7 @@ def __init__( ) ctrl_resnets.append( ResnetBlock2D( - in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl + in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl out_channels=ctrl_out_channels, temb_channels=temb_channels, groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups), @@ -1108,13 +1117,13 @@ def __init__( if has_crossattn: base_attentions.append( Transformer2DModel( - base_num_attention_heads, - base_out_channels // base_num_attention_heads, - in_channels=base_out_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - use_linear_projection=True, - upcast_attention=upcast_attention, + base_num_attention_heads, + base_out_channels // base_num_attention_heads, + in_channels=base_out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=True, + upcast_attention=upcast_attention, ) ) ctrl_attentions.append( @@ -1139,8 +1148,12 @@ def __init__( # Concat doesn't require change in number of channels base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) - self.base_downsamplers = Downsample2D(base_out_channels, use_conv=True, out_channels=base_out_channels, name="op") - self.ctrl_downsamplers = Downsample2D(ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op") + self.base_downsamplers = Downsample2D( + base_out_channels, use_conv=True, out_channels=base_out_channels, name="op" + ) + self.ctrl_downsamplers = Downsample2D( + ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op" + ) # After the downsampler application, information is added from control to base # Addition requires change in number of channels @@ -1151,32 +1164,30 @@ def __init__( self.base_resnets = nn.ModuleList(base_resnets) self.ctrl_resnets = nn.ModuleList(ctrl_resnets) - self.base_attentions = nn.ModuleList(base_attentions) if has_crossattn else [None]*num_layers - self.ctrl_attentions = nn.ModuleList(ctrl_attentions) if has_crossattn else [None]*num_layers + self.base_attentions = nn.ModuleList(base_attentions) if has_crossattn else [None] * num_layers + self.ctrl_attentions = nn.ModuleList(ctrl_attentions) if has_crossattn else [None] * num_layers self.base_to_ctrl = nn.ModuleList(base_to_ctrl) self.ctrl_to_base = nn.ModuleList(ctrl_to_base) self.gradient_checkpointing = False @classmethod - def from_modules( - cls, - base_downblock: CrossAttnDownBlock2D, - ctrl_downblock: nn.ModuleDict - ): + def from_modules(cls, base_downblock: CrossAttnDownBlock2D, ctrl_downblock: nn.ModuleDict): # get params def get_first_cross_attention(block): - return block.attentions[0].transformer_blocks[0].attn2 + return block.attentions[0].transformer_blocks[0].attn2 base_in_channels = base_downblock.resnets[0].in_channels base_out_channels = base_downblock.resnets[0].out_channels - ctrl_in_channels = ctrl_downblock["resnets"][0].in_channels - base_in_channels # base channels are concatted to ctrl channels in init + ctrl_in_channels = ( + ctrl_downblock["resnets"][0].in_channels - base_in_channels + ) # base channels are concatted to ctrl channels in init ctrl_out_channels = ctrl_downblock["resnets"][0].out_channels temb_channels = base_downblock.resnets[0].time_emb_proj.in_features num_groups = ctrl_downblock["resnets"][0].norm1.num_groups if hasattr(base_downblock, "attentions"): has_crossattn = True - transformer_layers_per_block = len(base_downblock.attentions) + transformer_layers_per_block = len(base_downblock.attentions[0].transformer_blocks) base_num_attention_heads = get_first_cross_attention(base_downblock).heads ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim @@ -1192,19 +1203,19 @@ def get_first_cross_attention(block): # create model model = cls( - base_in_channels = base_in_channels, - base_out_channels = base_out_channels, - ctrl_in_channels = ctrl_in_channels, - ctrl_out_channels = ctrl_out_channels, - temb_channels = temb_channels, - max_norm_num_groups = num_groups, + base_in_channels=base_in_channels, + base_out_channels=base_out_channels, + ctrl_in_channels=ctrl_in_channels, + ctrl_out_channels=ctrl_out_channels, + temb_channels=temb_channels, + max_norm_num_groups=num_groups, has_crossattn=has_crossattn, - transformer_layers_per_block = transformer_layers_per_block, - base_num_attention_heads = base_num_attention_heads, - ctrl_num_attention_heads = ctrl_num_attention_heads, - cross_attention_dim = cross_attention_dim, - add_downsample = add_downsample, - upcast_attention = upcast_attention, + transformer_layers_per_block=transformer_layers_per_block, + base_num_attention_heads=base_num_attention_heads, + ctrl_num_attention_heads=ctrl_num_attention_heads, + cross_attention_dim=cross_attention_dim, + add_downsample=add_downsample, + upcast_attention=upcast_attention, ) # # load weights @@ -1224,14 +1235,15 @@ def get_first_cross_attention(block): def forward( self, hidden_states_base: torch.FloatTensor, - hidden_states_ctrl: torch.FloatTensor, temb: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, + hidden_states_ctrl: Optional[torch.FloatTensor] = None, conditioning_scale: Optional[float] = 1.0, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: # todo umer: output type hint correct? + do_control: bool = True, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: # todo umer: output type hint correct? if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") @@ -1245,12 +1257,15 @@ def forward( base_blocks = list(zip(self.base_resnets, self.base_attentions)) ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions)) - for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip(base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base): + for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip( + base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base + ): if self.training and self.gradient_checkpointing: raise NotImplementedError("todo umer") else: # concat base -> ctrl - h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) + if do_control: + h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) # apply base subblock h_base = b_res(h_base, temb) @@ -1265,19 +1280,21 @@ def forward( )[0] # apply ctrl subblock - h_ctrl = c_res(h_ctrl, temb) - if c_attn is not None: - h_ctrl = c_attn( - h_ctrl, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + if do_control: + h_ctrl = c_res(h_ctrl, temb) + if c_attn is not None: + h_ctrl = c_attn( + h_ctrl, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] # add ctrl -> base - h_base = h_base + c2b(h_ctrl) * conditioning_scale + if do_control: + h_base = h_base + c2b(h_ctrl) * conditioning_scale base_output_states = base_output_states + (h_base,) ctrl_output_states = ctrl_output_states + (h_ctrl,) @@ -1287,18 +1304,21 @@ def forward( c2b = self.ctrl_to_base[-1] # concat base -> ctrl - h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) + if do_control: + h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) # apply base subblock h_base = self.base_downsamplers(h_base) # apply ctrl subblock - h_ctrl = self.ctrl_downsamplers(h_ctrl) + if do_control: + h_ctrl = self.ctrl_downsamplers(h_ctrl) # add ctrl -> base - h_base = h_base + c2b(h_ctrl) * conditioning_scale + if do_control: + h_base = h_base + c2b(h_ctrl) * conditioning_scale base_output_states = base_output_states + (h_base,) ctrl_output_states = ctrl_output_states + (h_ctrl,) - return h_base, h_ctrl,base_output_states, ctrl_output_states + return h_base, h_ctrl, base_output_states, ctrl_output_states class ControlNetXSCrossAttnMidBlock2D(nn.Module): @@ -1327,7 +1347,7 @@ def __init__( cross_attention_dim=cross_attention_dim, num_attention_heads=base_num_attention_heads, use_linear_projection=True, - upcast_attention=upcast_attention + upcast_attention=upcast_attention, ) self.ctrl_midblock = UNetMidBlock2DCrossAttn( @@ -1340,7 +1360,7 @@ def __init__( cross_attention_dim=cross_attention_dim, num_attention_heads=ctrl_num_attention_heads, use_linear_projection=True, - upcast_attention=upcast_attention + upcast_attention=upcast_attention, ) # After the midblock application, information is added from control to base @@ -1352,19 +1372,20 @@ def __init__( @classmethod def from_modules( cls, - base_to_ctrl: nn.Conv2d, base_midblock: UNetMidBlock2DCrossAttn, - ctrl_midblock: UNetMidBlock2DCrossAttn, - ctrl_to_base: nn.Conv2d - + ctrl_midblock_dict: nn.ModuleDict, ): + base_to_ctrl = ctrl_midblock_dict["base_to_ctrl"] + ctrl_to_base = ctrl_midblock_dict["ctrl_to_base"] + ctrl_midblock = ctrl_midblock_dict["midblock"] + # get params def get_first_cross_attention(midblock): return midblock.attentions[0].transformer_blocks[0].attn2 base_channels = ctrl_to_base.out_channels - ctrl_channels = ctrl_to_base.in_channels - transformer_layers_per_block = len(base_midblock.attentions) + ctrl_channels = ctrl_to_base.in_channels + transformer_layers_per_block = len(base_midblock.attentions[0].transformer_blocks) temb_channels = base_midblock.resnets[0].time_emb_proj.in_features num_groups = ctrl_midblock.resnets[0].norm1.num_groups base_num_attention_heads = get_first_cross_attention(base_midblock).heads @@ -1377,12 +1398,12 @@ def get_first_cross_attention(midblock): base_channels=base_channels, ctrl_channels=ctrl_channels, temb_channels=temb_channels, - max_norm_num_groups = num_groups, - transformer_layers_per_block = transformer_layers_per_block, - base_num_attention_heads = base_num_attention_heads, - ctrl_num_attention_heads = ctrl_num_attention_heads, - cross_attention_dim = cross_attention_dim, - upcast_attention = upcast_attention, + max_norm_num_groups=num_groups, + transformer_layers_per_block=transformer_layers_per_block, + base_num_attention_heads=base_num_attention_heads, + ctrl_num_attention_heads=ctrl_num_attention_heads, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) # load weights @@ -1396,14 +1417,15 @@ def get_first_cross_attention(midblock): def forward( self, hidden_states_base: torch.FloatTensor, - hidden_states_ctrl: torch.FloatTensor, temb: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, + encoder_hidden_states: torch.FloatTensor, + hidden_states_ctrl: Optional[torch.FloatTensor] = None, conditioning_scale: Optional[float] = 1.0, cross_attention_kwargs: Optional[Dict[str, Any]] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: # todo umer: output type hint correct? + do_control: bool = True, + ) -> torch.FloatTensor: # todo umer: output type hint correct? if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") @@ -1419,10 +1441,12 @@ def forward( "encoder_attention_mask": encoder_attention_mask, } - h_ctrl = torch.cat([h_ctrl, self.base_to_ctrl(h_base)], dim=1) # concat base -> ctrl + if do_control: + h_ctrl = torch.cat([h_ctrl, self.base_to_ctrl(h_base)], dim=1) # concat base -> ctrl h_base = self.base_midblock(h_base, **joint_args) # apply base mid block - h_ctrl = self.ctrl_midblock(h_ctrl, **joint_args) # apply ctrl mid block - h_base = h_base + self.ctrl_to_base(h_ctrl) * conditioning_scale # add ctrl -> base + if do_control: + h_ctrl = self.ctrl_midblock(h_ctrl, **joint_args) # apply ctrl mid block + h_base = h_base + self.ctrl_to_base(h_ctrl) * conditioning_scale # add ctrl -> base return h_base, h_ctrl @@ -1447,7 +1471,7 @@ def __init__( attentions = [] ctrl_to_base = [] - num_layers = 3 # only support sd + sdxl + num_layers = 3 # only support sd + sdxl self.has_cross_attention = has_crossattn self.num_attention_heads = num_attention_heads @@ -1483,7 +1507,7 @@ def __init__( ) self.resnets = nn.ModuleList(resnets) - self.attentions = nn.ModuleList(attentions) if has_crossattn else [None]*num_layers + self.attentions = nn.ModuleList(attentions) if has_crossattn else [None] * num_layers self.ctrl_to_base = nn.ModuleList(ctrl_to_base) if add_upsample: @@ -1494,24 +1518,19 @@ def __init__( self.gradient_checkpointing = False @classmethod - def from_modules( - cls, - base_upblock: CrossAttnUpBlock2D, - ctrl_to_base_skip_connections: nn.ModuleList - ): + def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_to_base_skip_connections: nn.ModuleList): # get params def get_first_cross_attention(block): - return block.attentions[0].transformer_blocks[0].attn2 + return block.attentions[0].transformer_blocks[0].attn2 - base_in_channels = base_upblock.resnets[0].in_channels - base_out_channels = base_upblock.resnets[0].out_channels - ctrl_in_channels = ctrl_downblock["resnets"][0].in_channels - base_in_channels # base channels are concatted to ctrl channels in init - ctrl_out_channels = ctrl_downblock["resnets"][0].out_channels + out_channels = base_upblock.resnets[0].out_channels + in_channels = base_upblock.resnets[-1].in_channels - out_channels + prev_output_channels = base_upblock.resnets[0].in_channels - out_channels + ctrl_skip_channelss = [c.in_channels for c in ctrl_to_base_skip_connections] temb_channels = base_upblock.resnets[0].time_emb_proj.in_features - num_groups = base_upblock["resnets"][0].norm1.num_groups if hasattr(base_upblock, "attentions"): has_crossattn = True - transformer_layers_per_block = len(base_upblock.attentions) + transformer_layers_per_block = len(base_upblock.attentions[0].transformer_blocks) num_attention_heads = get_first_cross_attention(base_upblock).heads cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim upcast_attention = get_first_cross_attention(base_upblock).upcast_attention @@ -1525,18 +1544,17 @@ def get_first_cross_attention(block): # create model model = cls( - # todo umer - # in_channels: int, - # out_channels: int, - # prev_output_channel: int, - # ctrl_skip_channels: List[int], - temb_channels = temb_channels, - has_crossattn=True, - transformer_layers_per_block = transformer_layers_per_block, - num_attention_heads = num_attention_heads, - cross_attention_dim = cross_attention_dim, - add_upsample = add_upsample, - upcast_attention = upcast_attention + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channels, + ctrl_skip_channels=ctrl_skip_channelss, + temb_channels=temb_channels, + has_crossattn=has_crossattn, + transformer_layers_per_block=transformer_layers_per_block, + num_attention_heads=num_attention_heads, + cross_attention_dim=cross_attention_dim, + add_upsample=add_upsample, + upcast_attention=upcast_attention, ) # load weights @@ -1549,12 +1567,11 @@ def get_first_cross_attention(block): return model - def forward( self, hidden_states: torch.FloatTensor, - res_hidden_states_tuple_base: Tuple[torch.FloatTensor, ...], # todo umer: why ... in type hint? - res_hidden_states_tuple_cltr: Tuple[torch.FloatTensor, ...], # todo umer: why ... in type hint? + res_hidden_states_tuple_base: Tuple[torch.FloatTensor, ...], # todo umer: why ... in type hint? + res_hidden_states_tuple_ctrl: Tuple[torch.FloatTensor, ...], # todo umer: why ... in type hint? temb: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, conditioning_scale: Optional[float] = 1.0, @@ -1562,7 +1579,8 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: # todo umer: output type hint correct? + do_control: bool = True, + ) -> torch.FloatTensor: # todo umer: output type hint correct? if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") @@ -1578,8 +1596,15 @@ def forward( resnet_with_upsampler = self.resnets[-1] attn_with_upsampler = self.attentions[-1] - for resnet, attn, c2b, res_h_base, res_h_ctrl in zip(resnets_without_upsampler, attn_without_upsampler, self.ctrl_to_base, reversed(res_hidden_states_tuple_base), reversed(res_hidden_states_tuple_cltr)): - hidden_states += c2b(res_h_ctrl) * conditioning_scale + for resnet, attn, c2b, res_h_base, res_h_ctrl in zip( + resnets_without_upsampler, + attn_without_upsampler, + self.ctrl_to_base, + reversed(res_hidden_states_tuple_base), + reversed(res_hidden_states_tuple_ctrl), + ): + if do_control: + hidden_states += c2b(res_h_ctrl) * conditioning_scale hidden_states = torch.cat([hidden_states, res_h_base], dim=1) if self.training and self.gradient_checkpointing: @@ -1599,9 +1624,9 @@ def forward( if self.upsamplers is not None: c2b = self.ctrl_to_base[-1] res_h_base = res_hidden_states_tuple_base[0] - res_h_ctrl = res_hidden_states_tuple_cltr[0] - - hidden_states += c2b(res_h_ctrl) * conditioning_scale + res_h_ctrl = res_hidden_states_tuple_ctrl[0] + if do_control: + hidden_states += c2b(res_h_ctrl) * conditioning_scale hidden_states = torch.cat([hidden_states, res_h_base], dim=1) hidden_states = resnet_with_upsampler(hidden_states, temb) From a272300ec29775eec8d4faa79cb0993ff7d5434c Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Mon, 25 Mar 2024 22:57:52 +0100 Subject: [PATCH 60/75] make style,quality,fix-copies & small changes --- src/diffusers/models/controlnet_xs.py | 85 ++++++++++--------- .../controlnet_xs/pipeline_controlnet_xs.py | 2 +- .../pipeline_controlnet_xs_sd_xl.py | 2 +- 3 files changed, 45 insertions(+), 44 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 999d567ee1b4..ebc65b1ee3fa 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -119,10 +119,6 @@ class ControlNetXSAddon(ModelMixin, ConfigMixin): The channel order of conditional image. Will convert to `rgb` if it's `bgr`. conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): The tuple of output channels for each block in the `controlnet_cond_embedding` layer. - time_embedding_input_dim (`int`, defaults to 320): - Dimension of input into time embedding. Needs to be same as in the base model. - time_embedding_dim (`int`, defaults to 1280): - Dimension of output from time embedding. Needs to be same as in the base model. time_embedding_mix (`float`, defaults to 1.0): If 0, then only the control addon's time embedding is used. If 1, then only the base unet's time embedding is used. @@ -338,22 +334,20 @@ def from_unet( num_attention_heads = unet.config.attention_head_dim model = cls( - conditioning_channels = conditioning_channels, - conditioning_channel_order = conditioning_channel_order, - conditioning_embedding_out_channels = conditioning_embedding_out_channels, - time_embedding_input_dim = unet.time_embedding.linear_1.in_features, - time_embedding_dim = unet.time_embedding.linear_1.out_features, - time_embedding_mix = time_embedding_mix, - learn_time_embedding = learn_time_embedding, - attention_head_dim = num_attention_heads, - block_out_channels = block_out_channels, - base_block_out_channels = unet.config.block_out_channels, - cross_attention_dim = unet.config.cross_attention_dim, - down_block_types = unet.config.down_block_types, - sample_size = unet.config.sample_size, - transformer_layers_per_block = unet.config.transformer_layers_per_block, - upcast_attention = unet.config.upcast_attention, - max_norm_num_groups = unet.config.norm_num_groups, + conditioning_channels=conditioning_channels, + conditioning_channel_order=conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + time_embedding_mix=time_embedding_mix, + learn_time_embedding=learn_time_embedding, + attention_head_dim=num_attention_heads, + block_out_channels=block_out_channels, + base_block_out_channels=unet.config.block_out_channels, + cross_attention_dim=unet.config.cross_attention_dim, + down_block_types=unet.config.down_block_types, + sample_size=unet.config.sample_size, + transformer_layers_per_block=unet.config.transformer_layers_per_block, + upcast_attention=unet.config.upcast_attention, + max_norm_num_groups=unet.config.norm_num_groups, ) # ensure that the ControlNetXSAddon is the same dtype as the UNet2DConditionModel @@ -367,8 +361,6 @@ def __init__( conditioning_channels: int = 3, conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), - time_embedding_input_dim: Optional[int] = 320, - time_embedding_dim: Optional[int] = 1280, time_embedding_mix: float = 1.0, learn_time_embedding: bool = False, attention_head_dim: Union[int, Tuple[int]] = 4, @@ -390,6 +382,9 @@ def __init__( self.sample_size = sample_size + time_embedding_input_dim = base_block_out_channels[0] + time_embedding_dim = base_block_out_channels[0] * 4 + # `num_attention_heads` defaults to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 @@ -738,7 +733,7 @@ def __init__( def from_unet( cls, unet: UNet2DConditionModel, - controlnet: Optional[ControlNetXSAddon] = None, + controlnet: Optional[ControlNetXSAddon] = None, size_ratio: Optional[float] = None, ctrl_block_out_channels: Optional[List[float]] = None, time_embedding_mix: Optional[float] = 1.0, @@ -787,22 +782,22 @@ def from_unet( # # load weights # from unet modules_from_unet = [ - 'time_embedding', - 'conv_in', - 'conv_norm_out', - 'conv_out', + "time_embedding", + "conv_in", + "conv_norm_out", + "conv_out", ] for m in modules_from_unet: - getattr(model, 'base_' + m).load_state_dict(getattr(unet, m).state_dict()) + getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict()) optional_modules_from_unet = [ - 'class_embedding', - 'add_time_proj', - 'add_embedding', + "class_embedding", + "add_time_proj", + "add_embedding", ] for m in optional_modules_from_unet: if hasattr(unet, m) and getattr(unet, m) is not None: - getattr(model, 'base_' + m).load_state_dict(getattr(unet, m).state_dict()) + getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict()) # from controlnet model.controlnet_cond_embedding.load_state_dict(controlnet.controlnet_cond_embedding.state_dict()) @@ -812,9 +807,15 @@ def from_unet( model.control_to_base_for_conv_in.load_state_dict(controlnet.control_to_base_for_conv_in.state_dict()) # from both - model.down_blocks = nn.ModuleList(ControlNetXSCrossAttnDownBlock2D.from_modules(b,c) for b,c in zip(unet.down_blocks, controlnet.down_blocks)) + model.down_blocks = nn.ModuleList( + ControlNetXSCrossAttnDownBlock2D.from_modules(b, c) + for b, c in zip(unet.down_blocks, controlnet.down_blocks) + ) model.mid_block = ControlNetXSCrossAttnMidBlock2D.from_modules(unet.mid_block, controlnet.mid_block) - model.up_blocks = nn.ModuleList(ControlNetXSCrossAttnUpBlock2D.from_modules(b,c) for b,c in zip(unet.up_blocks, controlnet.up_connections)) + model.up_blocks = nn.ModuleList( + ControlNetXSCrossAttnUpBlock2D.from_modules(b, c) + for b, c in zip(unet.up_blocks, controlnet.up_connections) + ) # ensure that the UNetControlNetXSModel is the same dtype as the UNet2DConditionModel model.to(unet.dtype) @@ -1011,7 +1012,7 @@ def forward( conditioning_scale=conditioning_scale, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, - do_control=do_control + do_control=do_control, ) hs_base.extend(residual_hb) hs_ctrl.extend(residual_hc) @@ -1025,7 +1026,7 @@ def forward( conditioning_scale=conditioning_scale, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, - do_control=do_control + do_control=do_control, ) # 3 - up @@ -1037,14 +1038,14 @@ def forward( hs_ctrl = hs_ctrl[:-n_resnets] h_base = up( hidden_states=h_base, - res_hidden_states_tuple_base= skips_hb, + res_hidden_states_tuple_base=skips_hb, res_hidden_states_tuple_ctrl=skips_hc, - temb= temb, + temb=temb, encoder_hidden_states=cemb, - conditioning_scale= conditioning_scale, - cross_attention_kwargs= cross_attention_kwargs, - attention_mask= attention_mask, - do_control=do_control + conditioning_scale=conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + do_control=do_control, ) # 4 - conv out diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 90b91f033cdd..58413be95a8e 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -144,7 +144,7 @@ def __init__( super().__init__() if isinstance(unet, UNet2DConditionModel): - unet = UNetControlNetXSModel.from_unet2d(unet, controlnet) + unet = UNetControlNetXSModel.from_unet(unet, controlnet) if safety_checker is None and requires_safety_checker: logger.warning( diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index cf826d96ed05..2c13ae9671f0 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -178,7 +178,7 @@ def __init__( super().__init__() if isinstance(unet, UNet2DConditionModel): - unet = UNetControlNetXSModel.from_unet2d(unet, controlnet) + unet = UNetControlNetXSModel.from_unet(unet, controlnet) ( vae_compatible, From 618f3a2e2a27636d02533a4de73dce1dd477fe8c Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Mon, 25 Mar 2024 23:50:01 +0100 Subject: [PATCH 61/75] Fixed freezing --- src/diffusers/models/controlnet_xs.py | 137 +++++++++++++----- .../controlnet_xs/pipeline_controlnet_xs.py | 2 +- .../pipeline_controlnet_xs_sd_xl.py | 2 +- 3 files changed, 102 insertions(+), 39 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index ebc65b1ee3fa..f453d8957326 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -126,12 +126,12 @@ class ControlNetXSAddon(ModelMixin, ConfigMixin): learn_time_embedding (`bool`, defaults to `False`): Whether a time embedding should be learned. If yes, `ControlNetXSModel` will combine the time embeddings of the base model and the addon. If no, `ControlNetXSModel` will use the base model's time embedding. - channels_base (`Dict[str, List[Tuple[int]]]`, defaults to `ControlNetXSAddon.gather_base_subblock_sizes((320,640,1280,1280))`): - Channels of each subblock of the base model. Use `ControlNetXSAddon.gather_base_subblock_sizes` to obtain them. attention_head_dim (`list[int]`, defaults to `[4]`): The dimension of the attention heads. block_out_channels (`list[int]`, defaults to `[4, 8, 16, 16]`): The tuple of output channels for each block. + base_block_out_channels (`list[int]`, defaults to `[320, 640, 1280, 1280]`): + The tuple of output channels for each block in the base unet. cross_attention_dim (`int`, defaults to 1024): The dimension of the cross attention features. down_block_types (`list[str]`, defaults to `["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]`): @@ -300,11 +300,10 @@ def from_unet( conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), ): r""" - todo umer Instantiate a [`ControlNetXSAddon`] from a [`UNet2DConditionModel`]. Parameters: - base_model (`UNet2DConditionModel`): + unet (`UNet2DConditionModel`): The UNet model we want to control. The dimensions of the ControlNetXSAddon will be adapted to it. size_ratio (float, *optional*, defaults to `None`): When given, block_out_channels is set to a fraction of the base model's block_out_channels. @@ -315,6 +314,14 @@ def from_unet( The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. learn_time_embedding (`bool`, defaults to `False`): Whether the `ControlNetXSAddon` should learn a time embedding. + time_embedding_mix (`float`, defaults to 1.0): + If 0, then only the control addon's time embedding is used. + If 1, then only the base unet's time embedding is used. + Otherwise, both are combined. + conditioning_channels (`int`, defaults to 3): + Number of channels of conditioning input (e.g. an image) + conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`): The tuple of output channel for each block in the `controlnet_cond_embedding` layer. """ @@ -522,27 +529,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): `UNetControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. It's default parameters are compatible with StableDiffusion. - Most of it's paremeters are passed to the underlying `UNet2DConditionModel`. See it's documentation for details. - - Parameters: - time_embedding_mix (`float`, defaults to 1.0): - If 0, then only the control addon's time embedding is used. - If 1, then only the base unet's time embedding is used. - Otherwise, both are combined. - ctrl_conditioning_channels (`int`, defaults to 3): - The number of channels of the control conditioning input. - ctrl_conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): - Block sizes of the `ControlNetConditioningEmbedding`. - ctrl_conditioning_channel_order (`str`, defaults to "rgb"): - The order of channels in the control conditioning input. - ctrl_learn_time_embedding (`bool`, defaults to False): - Whether the control addon should learn a time embedding. Needs to be `True` if `time_embedding_mix` > 0. - ctrl_block_out_channels (`tuple[int]`, defaults to `(4, 8, 16, 16)`): - The tuple of output channels for each block in the control addon. - ctrl_attention_head_dim (`int` or `tuple[int]`, defaults to 4): - The dimension of the attention heads in the control addon. - ctrl_max_norm_num_groups (`int`, defaults to 32): - The maximum number of groups to use for the normalization in the control addon. Can be reduced to fit the block sizes. + It's parameters are either passed to the underlying `UNet2DConditionModel` or used exactly like in `ControlNetXSAddon` . See their documentation for details. """ _supports_gradient_checkpointing = True @@ -736,13 +723,35 @@ def from_unet( controlnet: Optional[ControlNetXSAddon] = None, size_ratio: Optional[float] = None, ctrl_block_out_channels: Optional[List[float]] = None, - time_embedding_mix: Optional[float] = 1.0, - # todo umer: pass kwargs to ctrlnet + time_embedding_mix: Optional[float] = None, + ctrl_optional_kwargs: Optional[Dict] = None, ): - # # validate input + r""" + Instantiate a [`UNetControlNetXSModel`] from a [`UNet2DConditionModel`] and an optional [`ControlNetXSAddon`] . + Parameters: + unet (`UNet2DConditionModel`): + The UNet model we want to control. + controlnet (`ControlNetXSAddon`): + The ConntrolNet-XS addon with which the UNet will be fused. If none is given, a new ConntrolNet-XS addon will be created. + size_ratio (float, *optional*, defaults to `None`): + Used to contruct the controlnet if none is given. See [`ControlNetXSAddon.from_unet`] for details. + ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`): + Used to contruct the controlnet if none is given. See [`ControlNetXSAddon.from_unet`] for details, where this parameter is called `block_out_channels`. + time_embedding_mix (`float`, *optional*, defaults to None): + Used to contruct the controlnet if none is given. See [`ControlNetXSAddon.from_unet`] for details. + ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`): + Passed to the `init` of the new controlent if no controlent was given. + """ if controlnet is None: - controlnet = ControlNetXSAddon.from_unet(unet, size_ratio, ctrl_block_out_channels) + controlnet = ControlNetXSAddon.from_unet(unet, size_ratio, ctrl_block_out_channels, **ctrl_optional_kwargs) + else: + if any( + o is not None for o in (size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs) + ): + raise ValueError( + "When a controlnet is passed, none of these parameters should be passed: size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs." + ) # # get params params_for_unet = [ @@ -822,18 +831,34 @@ def from_unet( return model - def freeze_unet2d_params(self) -> None: - # todo umer - """Freeze the weights of just the UNet2DConditionModel, and leave the ControlNetXSAddon - unfrozen for fine tuning. - """ + def freeze_unet_params(self) -> None: + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine tuning.""" # Freeze everything for param in self.parameters(): - param.requires_grad = False + param.requires_grad = True # Unfreeze ControlNetXSAddon - for param in self.control_addon.parameters(): - param.requires_grad = True + base_parts = [ + "base_time_proj", + "base_time_embedding", + "base_class_embedding", + "base_add_time_proj", + "base_add_embedding", + "base_conv_in", + "base_conv_norm_out", + "base_conv_act", + "base_conv_out", + ] + base_parts = [getattr(self, part) for part in base_parts if getattr(self, part) is not None] + for part in base_parts: + for param in part.parameters(): + param.requires_grad = False + + for d in self.down_blocks: + d.freeze_base_params() + self.mid_block.freeze_base_params() + for u in self.up_blocks: + u.freeze_base_params() @torch.no_grad() def _check_if_vae_compatible(self, vae: AutoencoderKL): @@ -1233,6 +1258,20 @@ def get_first_cross_attention(block): return model + def freeze_base_params(self) -> None: + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine tuning.""" + # Unfreeze everything + for param in self.parameters(): + param.requires_grad = True + + # Freeze base part + base_parts = [self.base_resnets, self.base_attentions] + if self.base_downsamplers is not None: + base_parts.append(self.base_downsamplers) + for part in base_parts: + for param in part.parameters(): + param.requires_grad = False + def forward( self, hidden_states_base: torch.FloatTensor, @@ -1415,6 +1454,16 @@ def get_first_cross_attention(midblock): return model + def freeze_base_params(self) -> None: + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine tuning.""" + # Unfreeze everything + for param in self.parameters(): + param.requires_grad = True + + # Freeze base part + for param in self.base_midblock.parameters(): + param.requires_grad = False + def forward( self, hidden_states_base: torch.FloatTensor, @@ -1568,6 +1617,20 @@ def get_first_cross_attention(block): return model + def freeze_base_params(self) -> None: + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine tuning.""" + # Unfreeze everything + for param in self.parameters(): + param.requires_grad = True + + # Freeze base part + base_parts = [self.resnets, self.attentions] + if self.upsamplers is not None: + base_parts.append(self.upsamplers) + for part in base_parts: + for param in part.parameters(): + param.requires_grad = False + def forward( self, hidden_states: torch.FloatTensor, diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 58413be95a8e..0475ab1817ef 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -72,7 +72,7 @@ ... ) >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( ... "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16 - ... ) # paper used time_embedding_mix=1.0 + ... ) >>> pipe.enable_model_cpu_offload() >>> # get canny image diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index 2c13ae9671f0..075ef444c286 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -81,7 +81,7 @@ ... ) >>> # initialize the models and pipeline - >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization + >>> controlnet_conditioning_scale = 0.5 >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) >>> controlnet = ControlNetXSAddon.from_pretrained( ... "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16 From 586fc181873591f74ff99b6041b58a53262e5871 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Tue, 26 Mar 2024 15:39:16 +0100 Subject: [PATCH 62/75] Added gradient ckpt'ing; fixed tests --- src/diffusers/models/controlnet_xs.py | 489 ++++++++++-------- .../unets/test_models_unet_controlnetxs.py | 242 ++++++--- .../controlnet_xs/test_controlnetxs.py | 9 +- .../controlnet_xs/test_controlnetxs_sdxl.py | 2 +- 4 files changed, 443 insertions(+), 299 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index f453d8957326..bf4e2f050ed0 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -17,11 +17,11 @@ import torch import torch.utils.checkpoint -from torch import nn +from torch import FloatTensor, nn from torch.nn import functional as F from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, logging +from ..utils import BaseOutput, is_torch_version, logging from .autoencoders import AutoencoderKL from .embeddings import ( TimestepEmbedding, @@ -48,12 +48,12 @@ class ControlNetXSOutput(BaseOutput): The output of [`UNetControlNetXSModel`]. Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + sample (`FloatTensor` of shape `(batch_size, num_channels, height, width)`): The output of the `UNetControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model output, but is already the final output. """ - sample: torch.FloatTensor = None + sample: FloatTensor = None # copied from diffusers.models.controlnet.ControlNetConditioningEmbedding @@ -147,6 +147,157 @@ class ControlNetXSAddon(ModelMixin, ConfigMixin): Maximum number of groups in group normal. The actual number will the the largest divisor of the respective channels, that is <= max_norm_num_groups. """ + @register_to_config + def __init__( + self, + conditioning_channels: int = 3, + conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + time_embedding_mix: float = 1.0, + learn_time_embedding: bool = False, + attention_head_dim: Union[int, Tuple[int]] = 4, + block_out_channels: Tuple[int] = (4, 8, 16, 16), + base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + cross_attention_dim: int = 1024, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + sample_size: Optional[int] = 96, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + upcast_attention: bool = True, + max_norm_num_groups: int = 32, + ): + super().__init__() + + self.sample_size = sample_size + + time_embedding_input_dim = base_block_out_channels[0] + time_embedding_dim = base_block_out_channels[0] * 4 + + # `num_attention_heads` defaults to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = attention_head_dim + + # Check inputs + if conditioning_channel_order not in ["rgb", "bgr"]: + raise ValueError(f"unknown `conditioning_channel_order`: {conditioning_channel_order}") + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + transformer_layers_per_block = repeat_if_not_list( + transformer_layers_per_block, repetitions=len(down_block_types) + ) + cross_attention_dim = repeat_if_not_list(cross_attention_dim, repetitions=len(down_block_types)) + num_attention_heads = repeat_if_not_list( + num_attention_heads, repetitions=len(down_block_types) + ) # todo umer: im using # attn heads & dim attn heads. should only be one. + attention_head_dim = repeat_if_not_list(attention_head_dim, repetitions=len(down_block_types)) + + if len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + # 5 - Create conditioning hint embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + # time + if learn_time_embedding: + self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embedding_dim) + else: + self.time_embedding = None + + self.time_embed_act = None + + self.down_blocks = nn.ModuleList([]) + self.up_connections = nn.ModuleList([]) + + # input + self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1) + self.control_to_base_for_conv_in = make_zero_conv(block_out_channels[0], base_block_out_channels[0]) + + # down + base_out_channels = base_block_out_channels[0] + ctrl_out_channels = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + base_in_channels = base_out_channels + base_out_channels = base_block_out_channels[i] + ctrl_in_channels = ctrl_out_channels + ctrl_out_channels = block_out_channels[i] + has_crossattn = "CrossAttn" in down_block_type + is_final_block = i == len(down_block_types) - 1 + + self.down_blocks.append( + ControlNetXSAddon.get_down_block( + base_in_channels=base_in_channels, + base_out_channels=base_out_channels, + ctrl_in_channels=ctrl_in_channels, + ctrl_out_channels=ctrl_out_channels, + temb_channels=time_embedding_dim, + max_norm_num_groups=max_norm_num_groups, + has_crossattn=has_crossattn, + transformer_layers_per_block=transformer_layers_per_block[i], + num_attention_heads=attention_head_dim[i], + cross_attention_dim=cross_attention_dim[i], + add_downsample=not is_final_block, + upcast_attention=upcast_attention, + ) + ) + + # mid + self.mid_block = ControlNetXSAddon.get_mid_block( + base_channels=base_block_out_channels[-1], + ctrl_channels=block_out_channels[-1], + temb_channels=time_embedding_dim, + transformer_layers_per_block=transformer_layers_per_block[-1], + num_attention_heads=attention_head_dim[-1], + cross_attention_dim=cross_attention_dim[-1], + upcast_attention=upcast_attention, + ) + + # up + # The skip connection channels are the output of the conv_in and of all the down subblocks + ctrl_skip_channels = [block_out_channels[0]] + for i, out_channels in enumerate(block_out_channels): + number_of_subblocks = ( + 3 if i < len(block_out_channels) - 1 else 2 + ) # every block has 3 subblocks, except last one, which has 2 as it has no downsampler + ctrl_skip_channels.extend([out_channels] * number_of_subblocks) + + reversed_base_block_out_channels = list(reversed(base_block_out_channels)) + + base_out_channels = reversed_base_block_out_channels[0] + for i in range(len(down_block_types)): + prev_base_output_channel = base_out_channels + base_out_channels = reversed_base_block_out_channels[i] + ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)] + + self.up_connections.append( + ControlNetXSAddon.get_up_connections( + out_channels=base_out_channels, + prev_output_channel=prev_base_output_channel, + ctrl_skip_channels=ctrl_skip_channels_, + ) + ) + @staticmethod def get_down_block( base_in_channels: int, @@ -362,157 +513,6 @@ def from_unet( return model - @register_to_config - def __init__( - self, - conditioning_channels: int = 3, - conditioning_channel_order: str = "rgb", - conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), - time_embedding_mix: float = 1.0, - learn_time_embedding: bool = False, - attention_head_dim: Union[int, Tuple[int]] = 4, - block_out_channels: Tuple[int] = (4, 8, 16, 16), - base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - cross_attention_dim: int = 1024, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - sample_size: Optional[int] = 96, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - upcast_attention: bool = True, - max_norm_num_groups: int = 32, - ): - super().__init__() - - self.sample_size = sample_size - - time_embedding_input_dim = base_block_out_channels[0] - time_embedding_dim = base_block_out_channels[0] * 4 - - # `num_attention_heads` defaults to `attention_head_dim`. This looks weird upon first reading it and it is. - # The reason for this behavior is to correct for incorrectly named variables that were introduced - # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 - # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking - # which is why we correct for the naming here. - num_attention_heads = attention_head_dim - - # Check inputs - if conditioning_channel_order not in ["rgb", "bgr"]: - raise ValueError(f"unknown `conditioning_channel_order`: {conditioning_channel_order}") - - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - transformer_layers_per_block = repeat_if_not_list( - transformer_layers_per_block, repetitions=len(down_block_types) - ) - cross_attention_dim = repeat_if_not_list(cross_attention_dim, repetitions=len(down_block_types)) - num_attention_heads = repeat_if_not_list( - num_attention_heads, repetitions=len(down_block_types) - ) # todo umer: im using # attn heads & dim attn heads. should only be one. - attention_head_dim = repeat_if_not_list(attention_head_dim, repetitions=len(down_block_types)) - - if len(num_attention_heads) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." - ) - - if len(attention_head_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." - ) - - # 5 - Create conditioning hint embedding - self.controlnet_cond_embedding = ControlNetConditioningEmbedding( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=conditioning_embedding_out_channels, - conditioning_channels=conditioning_channels, - ) - - # time - if learn_time_embedding: - self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embedding_dim) - else: - self.time_embedding = None - - self.time_embed_act = None - - self.down_blocks = nn.ModuleList([]) - self.up_connections = nn.ModuleList([]) - - # input - self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1) - self.control_to_base_for_conv_in = make_zero_conv(block_out_channels[0], base_block_out_channels[0]) - - # down - base_out_channels = base_block_out_channels[0] - ctrl_out_channels = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - base_in_channels = base_out_channels - base_out_channels = base_block_out_channels[i] - ctrl_in_channels = ctrl_out_channels - ctrl_out_channels = block_out_channels[i] - has_crossattn = "CrossAttn" in down_block_type - is_final_block = i == len(down_block_types) - 1 - - self.down_blocks.append( - ControlNetXSAddon.get_down_block( - base_in_channels=base_in_channels, - base_out_channels=base_out_channels, - ctrl_in_channels=ctrl_in_channels, - ctrl_out_channels=ctrl_out_channels, - temb_channels=time_embedding_dim, - max_norm_num_groups=max_norm_num_groups, - has_crossattn=has_crossattn, - transformer_layers_per_block=transformer_layers_per_block[i], - num_attention_heads=attention_head_dim[i], - cross_attention_dim=cross_attention_dim[i], - add_downsample=not is_final_block, - upcast_attention=upcast_attention, - ) - ) - - # mid - self.mid_block = ControlNetXSAddon.get_mid_block( - base_channels=base_block_out_channels[-1], - ctrl_channels=block_out_channels[-1], - temb_channels=time_embedding_dim, - transformer_layers_per_block=transformer_layers_per_block[-1], - num_attention_heads=attention_head_dim[-1], - cross_attention_dim=cross_attention_dim[-1], - upcast_attention=upcast_attention, - ) - - # up - # The skip connection channels are the output of the conv_in and of all the down subblocks - ctrl_skip_channels = [block_out_channels[0]] - for i, out_channels in enumerate(block_out_channels): - number_of_subblocks = ( - 3 if i < len(block_out_channels) - 1 else 2 - ) # every block has 3 subblocks, except last one, which has 2 as it has no downsampler - ctrl_skip_channels.extend([out_channels] * number_of_subblocks) - - reversed_base_block_out_channels = list(reversed(base_block_out_channels)) - - base_out_channels = reversed_base_block_out_channels[0] - for i in range(len(down_block_types)): - prev_base_output_channel = base_out_channels - base_out_channels = reversed_base_block_out_channels[i] - ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)] - - self.up_connections.append( - ControlNetXSAddon.get_up_connections( - out_channels=base_out_channels, - prev_output_channel=prev_base_output_channel, - ctrl_skip_channels=ctrl_skip_channels_, - ) - ) - def forward(self, *args, **kwargs): raise ValueError( "A ControlNetXSAddonModel cannot be run by itself. Pass it into a ControlNetXSModel model instead." @@ -783,7 +783,7 @@ def from_unet( "max_norm_num_groups", ] params_for_controlnet = {"ctrl_" + k: v for k, v in controlnet.config.items() if k in params_for_controlnet} - params_for_controlnet["time_embedding_mix"] = time_embedding_mix + params_for_controlnet["time_embedding_mix"] = controlnet.config.time_embedding_mix # # create model model = cls.from_config({**params_for_unet, **params_for_controlnet}) @@ -873,7 +873,7 @@ def _set_gradient_checkpointing(self, module, value=False): def forward( self, - sample: torch.FloatTensor, + sample: FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, controlnet_cond: Optional[torch.Tensor] = None, @@ -890,13 +890,13 @@ def forward( The [`ControlNetXSModel`] forward method. Args: - sample (`torch.FloatTensor`): + sample (`FloatTensor`): The noisy input tensor. timestep (`Union[torch.Tensor, float, int]`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.Tensor`): The encoder hidden states. - controlnet_cond (`torch.FloatTensor`): + controlnet_cond (`FloatTensor`): The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. conditioning_scale (`float`, defaults to `1.0`): How much the control model affects the base model outputs. @@ -1265,7 +1265,9 @@ def freeze_base_params(self) -> None: param.requires_grad = True # Freeze base part - base_parts = [self.base_resnets, self.base_attentions] + base_parts = [self.base_resnets] + if isinstance(self.base_attentions, nn.ModuleList): # attentions can be a list of Nones + base_parts.append(self.base_attentions) if self.base_downsamplers is not None: base_parts.append(self.base_downsamplers) for part in base_parts: @@ -1274,16 +1276,16 @@ def freeze_base_params(self) -> None: def forward( self, - hidden_states_base: torch.FloatTensor, - temb: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - hidden_states_ctrl: Optional[torch.FloatTensor] = None, + hidden_states_base: FloatTensor, + temb: FloatTensor, + encoder_hidden_states: Optional[FloatTensor] = None, + hidden_states_ctrl: Optional[FloatTensor] = None, conditioning_scale: Optional[float] = 1.0, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[FloatTensor] = None, do_control: bool = True, - ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: # todo umer: output type hint correct? + ) -> Tuple[FloatTensor, FloatTensor, Tuple[FloatTensor, ...], Tuple[FloatTensor, ...]]: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") @@ -1297,21 +1299,52 @@ def forward( base_blocks = list(zip(self.base_resnets, self.base_attentions)) ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions)) + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + def apply_resnet(resnet, hidden_states, temb): + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + if self.training and self.gradient_checkpointing: + return torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + return resnet(hidden_states, temb) + for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip( base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base ): - if self.training and self.gradient_checkpointing: - raise NotImplementedError("todo umer") - else: - # concat base -> ctrl - if do_control: - h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) - - # apply base subblock - h_base = b_res(h_base, temb) - if b_attn is not None: - h_base = b_attn( - h_base, + # concat base -> ctrl + if do_control: + h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) + + # apply base subblock + h_base = apply_resnet(b_res, h_base, temb) + if b_attn is not None: + h_base = b_attn( + h_base, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply ctrl subblock + if do_control: + h_ctrl = apply_resnet(c_res, h_ctrl, temb) + if c_attn is not None: + h_ctrl = c_attn( + h_ctrl, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, @@ -1319,22 +1352,9 @@ def forward( return_dict=False, )[0] - # apply ctrl subblock - if do_control: - h_ctrl = c_res(h_ctrl, temb) - if c_attn is not None: - h_ctrl = c_attn( - h_ctrl, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - - # add ctrl -> base - if do_control: - h_base = h_base + c2b(h_ctrl) * conditioning_scale + # add ctrl -> base + if do_control: + h_base = h_base + c2b(h_ctrl) * conditioning_scale base_output_states = base_output_states + (h_base,) ctrl_output_states = ctrl_output_states + (h_ctrl,) @@ -1466,16 +1486,16 @@ def freeze_base_params(self) -> None: def forward( self, - hidden_states_base: torch.FloatTensor, - temb: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor, - hidden_states_ctrl: Optional[torch.FloatTensor] = None, + hidden_states_base: FloatTensor, + temb: FloatTensor, + encoder_hidden_states: FloatTensor, + hidden_states_ctrl: Optional[FloatTensor] = None, conditioning_scale: Optional[float] = 1.0, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[FloatTensor] = None, + encoder_attention_mask: Optional[FloatTensor] = None, do_control: bool = True, - ) -> torch.FloatTensor: # todo umer: output type hint correct? + ) -> Tuple[FloatTensor, FloatTensor]: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") @@ -1624,7 +1644,9 @@ def freeze_base_params(self) -> None: param.requires_grad = True # Freeze base part - base_parts = [self.resnets, self.attentions] + base_parts = [self.resnets] + if isinstance(self.attentions, nn.ModuleList): # attentions can be a list of Nones + base_parts.append(self.attentions) if self.upsamplers is not None: base_parts.append(self.upsamplers) for part in base_parts: @@ -1633,18 +1655,18 @@ def freeze_base_params(self) -> None: def forward( self, - hidden_states: torch.FloatTensor, - res_hidden_states_tuple_base: Tuple[torch.FloatTensor, ...], # todo umer: why ... in type hint? - res_hidden_states_tuple_ctrl: Tuple[torch.FloatTensor, ...], # todo umer: why ... in type hint? - temb: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, + hidden_states: FloatTensor, + res_hidden_states_tuple_base: Tuple[FloatTensor, ...], + res_hidden_states_tuple_ctrl: Tuple[FloatTensor, ...], + temb: FloatTensor, + encoder_hidden_states: Optional[FloatTensor] = None, conditioning_scale: Optional[float] = 1.0, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[FloatTensor] = None, upsample_size: Optional[int] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[FloatTensor] = None, do_control: bool = True, - ) -> torch.FloatTensor: # todo umer: output type hint correct? + ) -> FloatTensor: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") @@ -1660,6 +1682,27 @@ def forward( resnet_with_upsampler = self.resnets[-1] attn_with_upsampler = self.attentions[-1] + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + def apply_resnet(resnet, hidden_states, temb): + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + if self.training and self.gradient_checkpointing: + return torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + return resnet(hidden_states, temb) + for resnet, attn, c2b, res_h_base, res_h_ctrl in zip( resnets_without_upsampler, attn_without_upsampler, @@ -1670,20 +1713,16 @@ def forward( if do_control: hidden_states += c2b(res_h_ctrl) * conditioning_scale hidden_states = torch.cat([hidden_states, res_h_base], dim=1) - - if self.training and self.gradient_checkpointing: - raise NotImplementedError("todo umer") - else: - hidden_states = resnet(hidden_states, temb) - if attn is not None: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + hidden_states = apply_resnet(resnet, hidden_states, temb) + if attn is not None: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] if self.upsamplers is not None: c2b = self.ctrl_to_base[-1] diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index 2bdcd7aef42b..ddbddf6b6024 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -14,11 +14,11 @@ # limitations under the License. import copy -import re import unittest import numpy as np import torch +from torch import nn from diffusers import ControlNetXSAddon, UNet2DConditionModel, UNetControlNetXSModel from diffusers.utils import logging @@ -73,16 +73,14 @@ def prepare_init_args_and_inputs_for_common(self): "sample_size": 32, "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"), "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"), - "block_out_channels": (4, 8), - "norm_num_groups": 1, + "block_out_channels": (32, 64), "cross_attention_dim": 32, "transformer_layers_per_block": 1, "num_attention_heads": 8, "upcast_attention": False, - "ctrl_time_embedding_input_dim": 4, "ctrl_block_out_channels": [4, 8], "ctrl_attention_head_dim": 8, - "ctrl_max_norm_num_groups": 1, + "ctrl_max_norm_num_groups": 4, } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -90,7 +88,7 @@ def prepare_init_args_and_inputs_for_common(self): def get_dummy_unet(self): """For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet""" return UNet2DConditionModel( - block_out_channels=(4, 8), + block_out_channels=(32, 64), layers_per_block=2, sample_size=32, in_channels=4, @@ -98,69 +96,173 @@ def get_dummy_unet(self): down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), cross_attention_dim=32, - norm_num_groups=1, use_linear_projection=True, ) - def test_from_unet2d(self): + def test_from_unet(self): unet = self.get_dummy_unet() controlnet = ControlNetXSAddon.from_unet(unet, size_ratio=1) - model = UNetControlNetXSModel.from_unet2d(unet, controlnet) + model = UNetControlNetXSModel.from_unet(unet, controlnet) model_state_dict = model.state_dict() - def is_decomposed(module_name): - return "down_block" in module_name or "up_block" in module_name - - def block_to_subblock_name(param_name): - """ - Map name of a param from 'block notation' as in UNet to 'subblock notation' as in UNetControlNetXS - e.g. 'down_blocks.1.attentions.0.proj_in.weight' -> 'base_down_subblocks.3.attention.proj_in.weight' - """ - param_name = param_name.replace("down_blocks", "base_down_subblocks") - param_name = param_name.replace("up_blocks", "base_up_subblocks") - - numbers = re.findall(r"\d+", param_name) - block_idx, module_idx = int(numbers[0]), int(numbers[1]) - - layers_per_block = 2 - subblocks_per_block = layers_per_block + 1 # include down/upsampler - - if "downsampler" in param_name or "upsampler" in param_name: - subblock_idx = block_idx * subblocks_per_block + layers_per_block - else: - subblock_idx = block_idx * subblocks_per_block + module_idx - - param_name = re.sub(r"\d", str(subblock_idx), param_name, count=1) - param_name = re.sub(r"resnets\.\d+", "resnet", param_name) # eg resnets.1 -> resnet - param_name = re.sub(r"attentions\.\d+", "attention", param_name) # eg attentions.1 -> attention - param_name = re.sub(r"downsamplers\.\d+", "downsampler", param_name) # eg attentions.1 -> attention - param_name = re.sub(r"upsamplers\.\d+", "upsampler", param_name) # eg attentions.1 -> attention - - return param_name - - for param_name, param_value in unet.named_parameters(): - if is_decomposed(param_name): - # check unet modules that were decomposed - self.assertTrue(torch.equal(model_state_dict[block_to_subblock_name(param_name)], param_value)) - else: - # check unet modules that were copied as is - self.assertTrue(torch.equal(model_state_dict["base_" + param_name], param_value)) - - # check controlnet - for param_name, param_value in controlnet.named_parameters(): - self.assertTrue(torch.equal(model_state_dict["control_addon." + param_name], param_value)) - - def test_freeze_unet2d(self): + def assert_equal_weights(module, weight_dict_prefix): + for param_name, param_value in module.named_parameters(): + assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value) + + # # check unet + # everything expect down,mid,up blocks + modules_from_unet = [ + "time_embedding", + "conv_in", + "conv_norm_out", + "conv_out", + ] + for p in modules_from_unet: + assert_equal_weights(getattr(unet, p), "base_" + p) + optional_modules_from_unet = [ + "class_embedding", + "add_time_proj", + "add_embedding", + ] + for p in optional_modules_from_unet: + if hasattr(unet, p) and getattr(unet, p) is not None: + assert_equal_weights(getattr(unet, p), "base_" + p) + # down blocks + assert len(unet.down_blocks) == len(model.down_blocks) + for i, d in enumerate(unet.down_blocks): + assert_equal_weights(d.resnets, f"down_blocks.{i}.base_resnets") + if hasattr(d, "attentions"): + assert_equal_weights(d.attentions, f"down_blocks.{i}.base_attentions") + if hasattr(d, "downsamplers") and getattr(d, "downsamplers") is not None: + assert_equal_weights(d.downsamplers[0], f"down_blocks.{i}.base_downsamplers") + # mid block + assert_equal_weights(unet.mid_block, "mid_block.base_midblock") + # up blocks + assert len(unet.up_blocks) == len(model.up_blocks) + for i, u in enumerate(unet.up_blocks): + assert_equal_weights(u.resnets, f"up_blocks.{i}.resnets") + if hasattr(u, "attentions"): + assert_equal_weights(u.attentions, f"up_blocks.{i}.attentions") + if hasattr(u, "upsamplers") and getattr(u, "upsamplers") is not None: + assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers") + + # # check controlnet + # everything expect down,mid,up blocks + modules_from_controlnet = { + "controlnet_cond_embedding": "controlnet_cond_embedding", + "conv_in": "ctrl_conv_in", + "control_to_base_for_conv_in": "control_to_base_for_conv_in", + } + optional_modules_from_controlnet = {"time_embedding": "ctrl_time_embedding"} + for name_in_controlnet, name_in_unetcnxs in modules_from_controlnet.items(): + assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs) + + for name_in_controlnet, name_in_unetcnxs in optional_modules_from_controlnet.items(): + if hasattr(controlnet, name_in_controlnet) and getattr(controlnet, name_in_controlnet) is not None: + assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs) + # down blocks + assert len(controlnet.down_blocks) == len(model.down_blocks) + for i, d in enumerate(controlnet.down_blocks): + assert_equal_weights(d["resnets"], f"down_blocks.{i}.ctrl_resnets") + assert_equal_weights(d["base_to_ctrl"], f"down_blocks.{i}.base_to_ctrl") + assert_equal_weights(d["ctrl_to_base"], f"down_blocks.{i}.ctrl_to_base") + if "attentions" in d: + assert_equal_weights(d["attentions"], f"down_blocks.{i}.ctrl_attentions") + if "downsamplers" in d: + assert_equal_weights(d["downsamplers"], f"down_blocks.{i}.ctrl_downsamplers") + # mid block + assert_equal_weights(controlnet.mid_block["base_to_ctrl"], "mid_block.base_to_ctrl") + assert_equal_weights(controlnet.mid_block["midblock"], "mid_block.ctrl_midblock") + assert_equal_weights(controlnet.mid_block["ctrl_to_base"], "mid_block.ctrl_to_base") + # up blocks + assert len(controlnet.up_connections) == len(model.up_blocks) + for i, u in enumerate(controlnet.up_connections): + assert_equal_weights(u, f"up_blocks.{i}.ctrl_to_base") + + def test_freeze_unet(self): + def assert_frozen(module): + for p in module.parameters(): + assert not p.requires_grad + + def assert_unfrozen(module): + for p in module.parameters(): + assert p.requires_grad + init_dict, _ = self.prepare_init_args_and_inputs_for_common() model = UNetControlNetXSModel(**init_dict) - model.freeze_unet2d_params() - - for param_name, param_value in model.named_parameters(): - if "control_addon" not in param_name: - self.assertFalse(param_value.requires_grad) - else: - self.assertTrue(param_value.requires_grad) + model.freeze_unet_params() + + # # check unet + # everything expect down,mid,up blocks + modules_from_unet = [ + model.base_time_embedding, + model.base_conv_in, + model.base_conv_norm_out, + model.base_conv_out, + ] + for m in modules_from_unet: + assert_frozen(m) + + optional_modules_from_unet = [ + model.base_class_embedding, + model.base_add_time_proj, + model.base_add_embedding, + ] + for m in optional_modules_from_unet: + if m is not None: + assert_frozen(m) + + # down blocks + for i, d in enumerate(model.down_blocks): + assert_frozen(d.base_resnets) + if isinstance(d.base_attentions, nn.ModuleList): # attentions can be list of Nones + assert_frozen(d.base_attentions) + if d.base_downsamplers is not None: + assert_frozen(d.base_downsamplers) + + # mid block + assert_frozen(model.mid_block.base_midblock) + + # up blocks + for i, u in enumerate(model.up_blocks): + assert_frozen(u.resnets) + if isinstance(u.attentions, nn.ModuleList): # attentions can be list of Nones + assert_frozen(u.attentions) + if u.upsamplers is not None: + assert_frozen(u.upsamplers) + + # # check controlnet + # everything expect down,mid,up blocks + modules_from_controlnet = [ + model.controlnet_cond_embedding, + model.ctrl_conv_in, + model.control_to_base_for_conv_in, + ] + optional_modules_from_controlnet = [model.ctrl_time_embedding] + + for m in modules_from_controlnet: + assert_unfrozen(m) + for m in optional_modules_from_controlnet: + if m is not None: + assert_unfrozen(m) + + # down blocks + for d in model.down_blocks: + assert_unfrozen(d.ctrl_resnets) + assert_unfrozen(d.base_to_ctrl) + assert_unfrozen(d.ctrl_to_base) + if isinstance(d.ctrl_attentions, nn.ModuleList): # attentions can be list of Nones + assert_unfrozen(d.ctrl_attentions) + if d.ctrl_downsamplers is not None: + assert_unfrozen(d.ctrl_downsamplers) + # mid block + assert_unfrozen(model.mid_block.base_to_ctrl) + assert_unfrozen(model.mid_block.ctrl_midblock) + assert_unfrozen(model.mid_block.ctrl_to_base) + # up blocks + for u in model.up_blocks: + assert_unfrozen(u.ctrl_to_base) def test_gradient_checkpointing_is_applied(self): model_class_copy = copy.copy(UNetControlNetXSModel) @@ -187,9 +289,9 @@ def _set_gradient_checkpointing_new(self, module, value=False): EXPECTED_SET = { "Transformer2DModel", "UNetMidBlock2DCrossAttn", - "CrossAttnDownSubBlock2D", - "DownSubBlock2D", - "CrossAttnUpSubBlock2D", + "ControlNetXSCrossAttnDownBlock2D", + "ControlNetXSCrossAttnMidBlock2D", + "ControlNetXSCrossAttnUpBlock2D", } assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET @@ -199,7 +301,7 @@ def test_forward_no_control(self): unet = self.get_dummy_unet() controlnet = ControlNetXSAddon.from_unet(unet, size_ratio=1) - model = UNetControlNetXSModel.from_unet2d(unet, controlnet) + model = UNetControlNetXSModel.from_unet(unet, controlnet) unet = unet.to(torch_device) model = model.to(torch_device) @@ -213,15 +315,17 @@ def test_forward_no_control(self): unet_output = unet(**input_for_unet).sample.cpu() unet_controlnet_output = model(**input_, do_control=False).sample.cpu() - assert np.abs(unet_output.flatten() - unet_controlnet_output.flatten()).max() < 1e-5 + assert np.abs(unet_output.flatten() - unet_controlnet_output.flatten()).max() < 3e-4 def test_time_embedding_mixing(self): unet = self.get_dummy_unet() controlnet = ControlNetXSAddon.from_unet(unet, size_ratio=1) - controlnet_mix_time = ControlNetXSAddon.from_unet(unet, size_ratio=1, time_embedding_mix=0.5) + controlnet_mix_time = ControlNetXSAddon.from_unet( + unet, size_ratio=1, time_embedding_mix=0.5, learn_time_embedding=True + ) - model = UNetControlNetXSModel.from_unet2d(unet, controlnet) - model_mix_time = UNetControlNetXSModel.from_unet2d(unet, controlnet_mix_time) + model = UNetControlNetXSModel.from_unet(unet, controlnet) + model_mix_time = UNetControlNetXSModel.from_unet(unet, controlnet_mix_time) unet = unet.to(torch_device) model = model.to(torch_device) @@ -234,3 +338,7 @@ def test_time_embedding_mixing(self): output_mix_time = model_mix_time(**input_).sample assert output.shape == output_mix_time.shape + + def test_forward_with_norm_groups(self): + # UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups. + pass diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index eeebe544d8d4..5a77dbf20cdc 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -125,7 +125,7 @@ class ControlNetXSPipelineFastTests( def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) unet = UNet2DConditionModel( - block_out_channels=(4, 8), + block_out_channels=(32, 64), layers_per_block=2, sample_size=32, in_channels=4, @@ -133,13 +133,12 @@ def get_dummy_components(self, time_cond_proj_dim=None): down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), cross_attention_dim=32, - norm_num_groups=1, time_cond_proj_dim=time_cond_proj_dim, use_linear_projection=True, ) torch.manual_seed(0) controlnet = ControlNetXSAddon.from_unet( - base_model=unet, + unet=unet, size_ratio=0.5, num_attention_heads=2, learn_time_embedding=True, @@ -240,9 +239,7 @@ def test_controlnet_lcm(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array( - [0.52700454, 0.3930534, 0.25509018, 0.7132304, 0.53696585, 0.46568912, 0.7095368, 0.7059624, 0.4744786] - ) + expected_slice = np.array([0.491, 0.411, 0.292, 0.631, 0.506, 0.439, 0.664, 0.67, 0.447]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index 7295a7006eac..99d7409b6465 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -93,7 +93,7 @@ def get_dummy_components(self): ) torch.manual_seed(0) controlnet = ControlNetXSAddon.from_unet( - base_model=unet, + unet=unet, size_ratio=0.5, learn_time_embedding=True, conditioning_embedding_out_channels=(16, 32), From 34e9fdfb10b78d77af99597a0a284a7f53805b78 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Tue, 26 Mar 2024 17:40:22 +0100 Subject: [PATCH 63/75] Fix slow tests(+compile) ; clear naming confusion --- src/diffusers/models/controlnet_xs.py | 101 ++++++++---------- .../controlnet_xs/pipeline_controlnet_xs.py | 11 +- .../pipeline_controlnet_xs_sd_xl.py | 11 +- .../unets/test_models_unet_controlnetxs.py | 2 +- .../controlnet_xs/test_controlnetxs.py | 30 ++++-- .../controlnet_xs/test_controlnetxs_sdxl.py | 14 ++- 6 files changed, 79 insertions(+), 90 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index bf4e2f050ed0..aadac0ac83f8 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -126,8 +126,8 @@ class ControlNetXSAddon(ModelMixin, ConfigMixin): learn_time_embedding (`bool`, defaults to `False`): Whether a time embedding should be learned. If yes, `ControlNetXSModel` will combine the time embeddings of the base model and the addon. If no, `ControlNetXSModel` will use the base model's time embedding. - attention_head_dim (`list[int]`, defaults to `[4]`): - The dimension of the attention heads. + num_attention_heads (`list[int]`, defaults to `[4]`): + The number of attention heads. block_out_channels (`list[int]`, defaults to `[4, 8, 16, 16]`): The tuple of output channels for each block. base_block_out_channels (`list[int]`, defaults to `[320, 640, 1280, 1280]`): @@ -155,7 +155,7 @@ def __init__( conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), time_embedding_mix: float = 1.0, learn_time_embedding: bool = False, - attention_head_dim: Union[int, Tuple[int]] = 4, + num_attention_heads: Union[int, Tuple[int]] = 4, block_out_channels: Tuple[int] = (4, 8, 16, 16), base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280), cross_attention_dim: int = 1024, @@ -177,13 +177,6 @@ def __init__( time_embedding_input_dim = base_block_out_channels[0] time_embedding_dim = base_block_out_channels[0] * 4 - # `num_attention_heads` defaults to `attention_head_dim`. This looks weird upon first reading it and it is. - # The reason for this behavior is to correct for incorrectly named variables that were introduced - # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 - # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking - # which is why we correct for the naming here. - num_attention_heads = attention_head_dim - # Check inputs if conditioning_channel_order not in ["rgb", "bgr"]: raise ValueError(f"unknown `conditioning_channel_order`: {conditioning_channel_order}") @@ -197,21 +190,14 @@ def __init__( transformer_layers_per_block, repetitions=len(down_block_types) ) cross_attention_dim = repeat_if_not_list(cross_attention_dim, repetitions=len(down_block_types)) - num_attention_heads = repeat_if_not_list( - num_attention_heads, repetitions=len(down_block_types) - ) # todo umer: im using # attn heads & dim attn heads. should only be one. - attention_head_dim = repeat_if_not_list(attention_head_dim, repetitions=len(down_block_types)) + # see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why `ControlNetXSAddon` takes `num_attention_heads` instead of `attention_head_dim` + num_attention_heads = repeat_if_not_list(num_attention_heads, repetitions=len(down_block_types)) if len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) - if len(attention_head_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." - ) - # 5 - Create conditioning hint embedding self.controlnet_cond_embedding = ControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], @@ -255,7 +241,7 @@ def __init__( max_norm_num_groups=max_norm_num_groups, has_crossattn=has_crossattn, transformer_layers_per_block=transformer_layers_per_block[i], - num_attention_heads=attention_head_dim[i], + num_attention_heads=num_attention_heads[i], cross_attention_dim=cross_attention_dim[i], add_downsample=not is_final_block, upcast_attention=upcast_attention, @@ -268,7 +254,7 @@ def __init__( ctrl_channels=block_out_channels[-1], temb_channels=time_embedding_dim, transformer_layers_per_block=transformer_layers_per_block[-1], - num_attention_heads=attention_head_dim[-1], + num_attention_heads=num_attention_heads[-1], cross_attention_dim=cross_attention_dim[-1], upcast_attention=upcast_attention, ) @@ -497,7 +483,7 @@ def from_unet( conditioning_embedding_out_channels=conditioning_embedding_out_channels, time_embedding_mix=time_embedding_mix, learn_time_embedding=learn_time_embedding, - attention_head_dim=num_attention_heads, + num_attention_heads=num_attention_heads, block_out_channels=block_out_channels, base_block_out_channels=unet.config.block_out_channels, cross_attention_dim=unet.config.cross_attention_dim, @@ -564,7 +550,7 @@ def __init__( ctrl_conditioning_channel_order: str = "rgb", ctrl_learn_time_embedding: bool = False, ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16), - ctrl_attention_head_dim: Union[int, Tuple[int]] = 4, # todo umer: # attn heads or dim attn heads? + ctrl_num_attention_heads: Union[int, Tuple[int]] = 4, ctrl_max_norm_num_groups: int = 32, ): super().__init__() @@ -581,7 +567,7 @@ def __init__( ) cross_attention_dim = repeat_if_not_list(cross_attention_dim, repetitions=len(down_block_types)) base_num_attention_heads = repeat_if_not_list(num_attention_heads, repetitions=len(down_block_types)) - ctrl_attention_head_dim = repeat_if_not_list(ctrl_attention_head_dim, repetitions=len(down_block_types)) + ctrl_num_attention_heads = repeat_if_not_list(ctrl_num_attention_heads, repetitions=len(down_block_types)) # Create UNet and decompose it into subblocks, which we then save base_model = UNet2DConditionModel( @@ -653,7 +639,7 @@ def __init__( has_crossattn=has_crossattn, transformer_layers_per_block=transformer_layers_per_block[i], base_num_attention_heads=base_num_attention_heads[i], - ctrl_num_attention_heads=ctrl_attention_head_dim[i], + ctrl_num_attention_heads=ctrl_num_attention_heads[i], cross_attention_dim=cross_attention_dim[i], add_downsample=not is_final_block, upcast_attention=upcast_attention, @@ -667,7 +653,7 @@ def __init__( temb_channels=time_embed_dim, transformer_layers_per_block=transformer_layers_per_block[-1], base_num_attention_heads=base_num_attention_heads[-1], - ctrl_num_attention_heads=ctrl_attention_head_dim[-1], + ctrl_num_attention_heads=ctrl_num_attention_heads[-1], cross_attention_dim=cross_attention_dim[-1], upcast_attention=upcast_attention, ) @@ -779,7 +765,7 @@ def from_unet( "conditioning_channel_order", "learn_time_embedding", "block_out_channels", - "attention_head_dim", + "num_attention_heads", "max_norm_num_groups", ] params_for_controlnet = {"ctrl_" + k: v for k, v in controlnet.config.items() if k in params_for_controlnet} @@ -1308,18 +1294,6 @@ def custom_forward(*inputs): return custom_forward - def apply_resnet(resnet, hidden_states, temb): - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - if self.training and self.gradient_checkpointing: - return torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - else: - return resnet(hidden_states, temb) - for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip( base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base ): @@ -1328,7 +1302,17 @@ def apply_resnet(resnet, hidden_states, temb): h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) # apply base subblock - h_base = apply_resnet(b_res, h_base, temb) + if self.training and self.gradient_checkpointing: + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + h_base = torch.utils.checkpoint.checkpoint( + create_custom_forward(b_res), + h_base, + temb, + **ckpt_kwargs, + ) + else: + h_base = b_res(h_base, temb) + if b_attn is not None: h_base = b_attn( h_base, @@ -1341,7 +1325,16 @@ def apply_resnet(resnet, hidden_states, temb): # apply ctrl subblock if do_control: - h_ctrl = apply_resnet(c_res, h_ctrl, temb) + if self.training and self.gradient_checkpointing: + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + h_ctrl = torch.utils.checkpoint.checkpoint( + create_custom_forward(c_res), + h_ctrl, + temb, + **ckpt_kwargs, + ) + else: + h_ctrl = c_res(h_ctrl, temb) if c_attn is not None: h_ctrl = c_attn( h_ctrl, @@ -1691,18 +1684,6 @@ def custom_forward(*inputs): return custom_forward - def apply_resnet(resnet, hidden_states, temb): - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - if self.training and self.gradient_checkpointing: - return torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - else: - return resnet(hidden_states, temb) - for resnet, attn, c2b, res_h_base, res_h_ctrl in zip( resnets_without_upsampler, attn_without_upsampler, @@ -1712,8 +1693,20 @@ def apply_resnet(resnet, hidden_states, temb): ): if do_control: hidden_states += c2b(res_h_ctrl) * conditioning_scale + hidden_states = torch.cat([hidden_states, res_h_base], dim=1) - hidden_states = apply_resnet(resnet, hidden_states, temb) + + if self.training and self.gradient_checkpointing: + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + if attn is not None: hidden_states = attn( hidden_states, diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 0475ab1817ef..85ed6f1be173 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -543,7 +543,7 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - # Check `image` + # Check `image` and `controlnet_conditioning_scale` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( self.unet, torch._dynamo.eval_frame.OptimizedModule ) @@ -553,15 +553,6 @@ def check_inputs( and isinstance(self.unet._orig_mod, UNetControlNetXSModel) ): self.check_image(image, prompt, prompt_embeds) - else: - assert False - - # Check `controlnet_conditioning_scale` - if ( - isinstance(self.unet, UNetControlNetXSModel) - or is_compiled - and isinstance(self.unet._orig_mod, UNetControlNetXSModel) - ): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") else: diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index 075ef444c286..a16c6dac3e64 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -589,7 +589,7 @@ def check_inputs( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) - # Check `image` + # Check `image` and ``controlnet_conditioning_scale`` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( self.unet, torch._dynamo.eval_frame.OptimizedModule ) @@ -599,15 +599,6 @@ def check_inputs( and isinstance(self.unet._orig_mod, UNetControlNetXSModel) ): self.check_image(image, prompt, prompt_embeds) - else: - assert False - - # Check `controlnet_conditioning_scale` - if ( - isinstance(self.unet, UNetControlNetXSModel) - or is_compiled - and isinstance(self.unet._orig_mod, UNetControlNetXSModel) - ): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") else: diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index ddbddf6b6024..588766516ccf 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -79,7 +79,7 @@ def prepare_init_args_and_inputs_for_common(self): "num_attention_heads": 8, "upcast_attention": False, "ctrl_block_out_channels": [4, 8], - "ctrl_attention_head_dim": 8, + "ctrl_num_attention_heads": 8, "ctrl_max_norm_num_groups": 4, } inputs_dict = self.dummy_input diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index 5a77dbf20cdc..aad49029a4bf 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -74,16 +74,20 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout): try: _ = in_queue.get(timeout=timeout) + controlnet = ControlNetXSAddon.from_pretrained( + "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 + ) pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - base_path="stabilityai/stable-diffusion-2-1-base", - base_kwargs={"safety_checker": None}, - addon_path="UmerHA/Testing-ConrolNetXS-SD2.1-canny", + "stabilityai/stable-diffusion-2-1-base", + controlnet=controlnet, + safety_checker=None, + torch_dtype=torch.float16, ) pipe.to("cuda") pipe.set_progress_bar_config(disable=None) - pipe.controlnet.to(memory_format=torch.channels_last) - pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True) + pipe.unet.to(memory_format=torch.channels_last) + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) generator = torch.Generator(device="cpu").manual_seed(0) prompt = "bird" @@ -300,9 +304,11 @@ def tearDown(self): torch.cuda.empty_cache() def test_canny(self): + controlnet = ControlNetXSAddon.from_pretrained( + "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 + ) pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - base_path="stabilityai/stable-diffusion-2-1-base", - addon_path="UmerHA/Testing-ConrolNetXS-SD2.1-canny", + "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16 ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -320,13 +326,15 @@ def test_canny(self): assert image.shape == (768, 512, 3) original_image = image[-3:, -3:, -1].flatten() - expected_image = np.array([0.1276, 0.1405, 0.1474, 0.1188, 0.1559, 0.1496, 0.1569, 0.1478, 0.1706]) + expected_image = np.array([0.1963, 0.229, 0.2659, 0.2109, 0.2332, 0.2827, 0.2534, 0.2422, 0.2808]) assert np.allclose(original_image, expected_image, atol=1e-04) def test_depth(self): + controlnet = ControlNetXSAddon.from_pretrained( + "UmerHA/Testing-ConrolNetXS-SD2.1-depth", torch_dtype=torch.float16 + ) pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - base_path="stabilityai/stable-diffusion-2-1-base", - addon_path="UmerHA/Testing-ConrolNetXS-SD2.1-depth", + "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16 ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -344,7 +352,7 @@ def test_depth(self): assert image.shape == (512, 512, 3) original_image = image[-3:, -3:, -1].flatten() - expected_image = np.array([0.1101, 0.1026, 0.1212, 0.114, 0.1169, 0.1266, 0.1191, 0.1266, 0.1712]) + expected_image = np.array([0.4844, 0.4937, 0.4956, 0.4663, 0.5039, 0.5044, 0.4565, 0.4883, 0.4941]) assert np.allclose(original_image, expected_image, atol=1e-04) @require_python39_or_higher diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index 99d7409b6465..d3c846041618 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -375,8 +375,11 @@ def tearDown(self): torch.cuda.empty_cache() def test_canny(self): + controlnet = ControlNetXSAddon.from_pretrained( + "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16 + ) pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( - base_path="stabilityai/stable-diffusion-xl-base-1.0", addon_path="UmerHA/Testing-ConrolNetXS-SDXL-canny" + "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 ) pipe.enable_sequential_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -392,12 +395,15 @@ def test_canny(self): assert images[0].shape == (768, 512, 3) original_image = images[0, -3:, -3:, -1].flatten() - expected_image = np.array([0.4371, 0.4341, 0.4620, 0.4524, 0.4680, 0.4504, 0.4530, 0.4505, 0.4390]) + expected_image = np.array([0.3202, 0.3151, 0.3328, 0.3172, 0.337, 0.3381, 0.3378, 0.3389, 0.3224]) assert np.allclose(original_image, expected_image, atol=1e-04) def test_depth(self): + controlnet = ControlNetXSAddon.from_pretrained( + "UmerHA/Testing-ConrolNetXS-SDXL-depth", torch_dtype=torch.float16 + ) pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( - base_path="stabilityai/stable-diffusion-xl-base-1.0", addon_path="UmerHA/Testing-ConrolNetXS-SDXL-depth" + "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 ) pipe.enable_sequential_cpu_offload() pipe.set_progress_bar_config(disable=None) @@ -413,5 +419,5 @@ def test_depth(self): assert images[0].shape == (512, 512, 3) original_image = images[0, -3:, -3:, -1].flatten() - expected_image = np.array([0.4082, 0.3880, 0.2779, 0.2654, 0.327, 0.372, 0.3761, 0.3442, 0.3122]) + expected_image = np.array([0.5448, 0.5437, 0.5426, 0.5543, 0.553, 0.5475, 0.5595, 0.5602, 0.5529]) assert np.allclose(original_image, expected_image, atol=1e-04) From 1e37e8ee8784ec656d6264203be7d5f6aa834a2d Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 3 Apr 2024 04:15:00 +0200 Subject: [PATCH 64/75] Don't create UNet in init ; removed class_emb --- src/diffusers/models/controlnet_xs.py | 86 +++++++------------ .../unets/test_models_unet_controlnetxs.py | 1 - 2 files changed, 32 insertions(+), 55 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index aadac0ac83f8..ae49f1bf550d 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -23,9 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, is_torch_version, logging from .autoencoders import AutoencoderKL -from .embeddings import ( - TimestepEmbedding, -) +from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unets.unet_2d_blocks import ( CrossAttnDownBlock2D, @@ -537,7 +535,6 @@ def __init__( cross_attention_dim: Union[int, Tuple[int]] = 1024, transformer_layers_per_block: Union[int, Tuple[int]] = 1, num_attention_heads: Union[int, Tuple[int]] = 8, - class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, upcast_attention: bool = True, @@ -562,6 +559,11 @@ def __init__( "To use `time_embedding_mix` < 1, initialize `ctrl_addon` with `learn_time_embedding = True`" ) + if addition_embed_type is not None and addition_embed_type != "text_time": + raise ValueError( + "As `UNetControlNetXSModel` currently only supports StableDiffusion and StableDiffusion-XL, `addition_embed_type` must be `None` or `'text_time'`." + ) + transformer_layers_per_block = repeat_if_not_list( transformer_layers_per_block, repetitions=len(down_block_types) ) @@ -569,55 +571,39 @@ def __init__( base_num_attention_heads = repeat_if_not_list(num_attention_heads, repetitions=len(down_block_types)) ctrl_num_attention_heads = repeat_if_not_list(ctrl_num_attention_heads, repetitions=len(down_block_types)) - # Create UNet and decompose it into subblocks, which we then save - base_model = UNet2DConditionModel( - sample_size=sample_size, - down_block_types=down_block_types, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - norm_num_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - transformer_layers_per_block=transformer_layers_per_block, - attention_head_dim=num_attention_heads, - use_linear_projection=True, - upcast_attention=upcast_attention, - class_embed_type=class_embed_type, - addition_embed_type=addition_embed_type, - time_cond_proj_dim=time_cond_proj_dim, - projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, - addition_time_embed_dim=addition_time_embed_dim, - ) - self.in_channels = 4 - time_embed_input_dim = block_out_channels[0] - time_embed_dim = block_out_channels[0] * 4 - - self.base_time_proj = base_model.time_proj - self.base_time_embedding = base_model.time_embedding - self.base_class_embedding = base_model.class_embedding - self.base_add_time_proj = base_model.add_time_proj if hasattr(base_model, "add_time_proj") else None - self.base_add_embedding = base_model.add_embedding if hasattr(base_model, "add_embedding") else None - - self.base_conv_in = base_model.conv_in - self.base_conv_norm_out = base_model.conv_norm_out - self.base_conv_act = base_model.conv_act - self.base_conv_out = base_model.conv_out - + # # Input + self.base_conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1) self.controlnet_cond_embedding = ControlNetConditioningEmbedding( conditioning_embedding_channels=ctrl_block_out_channels[0], block_out_channels=ctrl_conditioning_embedding_out_channels, conditioning_channels=ctrl_conditioning_channels, ) self.ctrl_conv_in = nn.Conv2d(4, ctrl_block_out_channels[0], kernel_size=3, padding=1) - self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim) - self.control_to_base_for_conv_in = make_zero_conv(ctrl_block_out_channels[0], block_out_channels[0]) - down_blocks = [] - up_blocks = [] + # # Time + time_embed_input_dim = block_out_channels[0] + time_embed_dim = block_out_channels[0] * 4 + + self.base_time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0) + self.base_time_embedding = TimestepEmbedding( + time_embed_input_dim, + time_embed_dim, + cond_proj_dim=time_cond_proj_dim, + ) + self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim) + + if addition_embed_type is None: + self.base_add_time_proj = None + self.base_add_embedding = None + else: + self.base_add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.base_add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) # # Create down blocks + down_blocks = [] base_out_channels = block_out_channels[0] ctrl_out_channels = ctrl_block_out_channels[0] for i, down_block_type in enumerate(down_block_types): @@ -659,6 +645,7 @@ def __init__( ) # # Create up blocks + up_blocks = [] rev_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) rev_num_attention_heads = list(reversed(base_num_attention_heads)) rev_cross_attention_dim = list(reversed(cross_attention_dim)) @@ -702,6 +689,10 @@ def __init__( self.down_blocks = nn.ModuleList(down_blocks) self.up_blocks = nn.ModuleList(up_blocks) + self.base_conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups) + self.base_conv_act = nn.SiLU() + self.base_conv_out = nn.Conv2d(block_out_channels[0], 4, kernel_size=3, padding=1) + @classmethod def from_unet( cls, @@ -748,7 +739,6 @@ def from_unet( "norm_num_groups", "cross_attention_dim", "transformer_layers_per_block", - "class_embed_type", "addition_embed_type", "addition_time_embed_dim", "upcast_attention", @@ -786,7 +776,6 @@ def from_unet( getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict()) optional_modules_from_unet = [ - "class_embedding", "add_time_proj", "add_embedding", ] @@ -827,7 +816,6 @@ def freeze_unet_params(self) -> None: base_parts = [ "base_time_proj", "base_time_embedding", - "base_class_embedding", "base_add_time_proj", "base_add_embedding", "base_conv_in", @@ -956,16 +944,6 @@ def forward( # added time & text embeddings aug_emb = None - if self.base_class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if self.config.class_embed_type == "timestep": - class_labels = self.base_time_proj(class_labels) - - class_emb = self.base_class_embedding(class_labels).to(dtype=self.dtype) - temb = temb + class_emb - if self.config.addition_embed_type is None: pass elif self.config.addition_embed_type == "text_time": diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index 588766516ccf..d34e9fce3c74 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -205,7 +205,6 @@ def assert_unfrozen(module): assert_frozen(m) optional_modules_from_unet = [ - model.base_class_embedding, model.base_add_time_proj, model.base_add_embedding, ] From 993103060f2540ac0343d125f6c0f298b1d5169e Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 3 Apr 2024 19:40:45 +0200 Subject: [PATCH 65/75] Incorporated review feedback - Deleted get_base_pipeline / get_controlnet_addon for pipes - Pipes inherit from StableDiffusionXLPipeline - Made module dicts for cnxs-addon's down/mid/up classes - Added support for qkv fusion and freeU --- src/diffusers/models/controlnet_xs.py | 251 +++++++++++++++--- .../controlnet_xs/pipeline_controlnet_xs.py | 50 +--- .../pipeline_controlnet_xs_sd_xl.py | 46 +--- .../unets/test_models_unet_controlnetxs.py | 22 +- .../controlnet_xs/test_controlnetxs.py | 7 +- tests/pipelines/test_pipelines_common.py | 3 +- 6 files changed, 244 insertions(+), 135 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index ae49f1bf550d..9f277e286ffb 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,6 +22,8 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, is_torch_version, logging +from ..utils.torch_utils import apply_freeu +from .attention_processor import Attention, AttentionProcessor from .autoencoders import AutoencoderKL from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin @@ -54,6 +56,33 @@ class ControlNetXSOutput(BaseOutput): sample: FloatTensor = None +class ControlNetXSAddonDownBlockComponents(nn.Module): + """Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnDownBlock2D`""" + def __init__(self, resnets: nn.ModuleList, base_to_ctrl:nn.ModuleList, ctrl_to_base:nn.ModuleList, attentions: Optional[nn.ModuleList] = None,downsampler: Optional[nn.Conv2d] = None): + super().__init__() + self.resnets = resnets + self.base_to_ctrl = base_to_ctrl + self.ctrl_to_base = ctrl_to_base + self.attentions = attentions + self.downsamplers = downsampler + + +class ControlNetXSAddonMidBlockComponents(nn.Module): + """Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnMidBlock2D`""" + def __init__(self, midblock: UNetMidBlock2DCrossAttn, base_to_ctrl:nn.ModuleList, ctrl_to_base:nn.ModuleList): + super().__init__() + self.midblock = midblock + self.base_to_ctrl = base_to_ctrl + self.ctrl_to_base = ctrl_to_base + + +class ControlNetXSAddonUpBlockComponents(nn.Module): + """Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnUpBlock2D`""" + def __init__(self, ctrl_to_base:nn.ModuleList): + super().__init__() + self.ctrl_to_base = ctrl_to_base + + # copied from diffusers.models.controlnet.ControlNetConditioningEmbedding class ControlNetConditioningEmbedding(nn.Module): """ @@ -230,7 +259,7 @@ def __init__( is_final_block = i == len(down_block_types) - 1 self.down_blocks.append( - ControlNetXSAddon.get_down_block( + self.get_down_block( base_in_channels=base_in_channels, base_out_channels=base_out_channels, ctrl_in_channels=ctrl_in_channels, @@ -247,7 +276,7 @@ def __init__( ) # mid - self.mid_block = ControlNetXSAddon.get_mid_block( + self.mid_block = self.get_mid_block( base_channels=base_block_out_channels[-1], ctrl_channels=block_out_channels[-1], temb_channels=time_embedding_dim, @@ -275,7 +304,7 @@ def __init__( ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)] self.up_connections.append( - ControlNetXSAddon.get_up_connections( + self.get_up_connections( out_channels=base_out_channels, prev_output_channel=prev_base_output_channel, ctrl_skip_channels=ctrl_skip_channels_, @@ -359,19 +388,18 @@ def get_down_block( else: downsamplers = None - module_dict = nn.ModuleDict( - { - "resnets": nn.ModuleList(resnets), - "base_to_ctrl": nn.ModuleList(base_to_ctrl), - "ctrl_to_base": nn.ModuleList(ctrl_to_base), - } + down_block_components = ControlNetXSAddonDownBlockComponents( + resnets=nn.ModuleList(resnets), + base_to_ctrl=nn.ModuleList(base_to_ctrl), + ctrl_to_base=nn.ModuleList(ctrl_to_base) ) + if has_crossattn: - module_dict["attentions"] = nn.ModuleList(attentions) + down_block_components.attentions = nn.ModuleList(attentions) if downsamplers is not None: - module_dict["downsamplers"] = downsamplers + down_block_components.downsamplers = downsamplers - return module_dict + return down_block_components @staticmethod def get_mid_block( @@ -405,7 +433,7 @@ def get_mid_block( # Addition requires change in number of channels ctrl_to_base = make_zero_conv(ctrl_channels, base_channels) - return nn.ModuleDict({"base_to_ctrl": base_to_ctrl, "midblock": midblock, "ctrl_to_base": ctrl_to_base}) + return ControlNetXSAddonMidBlockComponents(base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base) @staticmethod def get_up_connections( @@ -419,7 +447,7 @@ def get_up_connections( resnet_in_channels = prev_output_channel if i == 0 else out_channels ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels)) - return nn.ModuleList(ctrl_to_base) + return ControlNetXSAddonUpBlockComponents(ctrl_to_base=nn.ModuleList(ctrl_to_base)) @classmethod def from_unet( @@ -677,6 +705,7 @@ def __init__( prev_output_channel=prev_output_channel, ctrl_skip_channels=ctrl_skip_channels_, temb_channels=time_embed_dim, + resolution_idx=i, has_crossattn=has_crossattn, transformer_layers_per_block=rev_transformer_layers_per_block[-1], num_attention_heads=rev_num_attention_heads[-1], @@ -845,6 +874,138 @@ def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value + # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + def forward( self, sample: FloatTensor, @@ -1162,7 +1323,7 @@ def __init__( self.gradient_checkpointing = False @classmethod - def from_modules(cls, base_downblock: CrossAttnDownBlock2D, ctrl_downblock: nn.ModuleDict): + def from_modules(cls, base_downblock: CrossAttnDownBlock2D, ctrl_downblock: ControlNetXSAddonDownBlockComponents): # get params def get_first_cross_attention(block): return block.attentions[0].transformer_blocks[0].attn2 @@ -1170,11 +1331,11 @@ def get_first_cross_attention(block): base_in_channels = base_downblock.resnets[0].in_channels base_out_channels = base_downblock.resnets[0].out_channels ctrl_in_channels = ( - ctrl_downblock["resnets"][0].in_channels - base_in_channels + ctrl_downblock.resnets[0].in_channels - base_in_channels ) # base channels are concatted to ctrl channels in init - ctrl_out_channels = ctrl_downblock["resnets"][0].out_channels + ctrl_out_channels = ctrl_downblock.resnets[0].out_channels temb_channels = base_downblock.resnets[0].time_emb_proj.in_features - num_groups = ctrl_downblock["resnets"][0].norm1.num_groups + num_groups = ctrl_downblock.resnets[0].norm1.num_groups if hasattr(base_downblock, "attentions"): has_crossattn = True transformer_layers_per_block = len(base_downblock.attentions[0].transformer_blocks) @@ -1210,15 +1371,15 @@ def get_first_cross_attention(block): # # load weights model.base_resnets.load_state_dict(base_downblock.resnets.state_dict()) - model.ctrl_resnets.load_state_dict(ctrl_downblock["resnets"].state_dict()) + model.ctrl_resnets.load_state_dict(ctrl_downblock.resnets.state_dict()) if has_crossattn: model.base_attentions.load_state_dict(base_downblock.attentions.state_dict()) - model.ctrl_attentions.load_state_dict(ctrl_downblock["attentions"].state_dict()) + model.ctrl_attentions.load_state_dict(ctrl_downblock.attentions.state_dict()) if add_downsample: model.base_downsamplers.load_state_dict(base_downblock.downsamplers[0].state_dict()) - model.ctrl_downsamplers.load_state_dict(ctrl_downblock["downsamplers"].state_dict()) - model.base_to_ctrl.load_state_dict(ctrl_downblock["base_to_ctrl"].state_dict()) - model.ctrl_to_base.load_state_dict(ctrl_downblock["ctrl_to_base"].state_dict()) + model.ctrl_downsamplers.load_state_dict(ctrl_downblock.downsamplers.state_dict()) + model.base_to_ctrl.load_state_dict(ctrl_downblock.base_to_ctrl.state_dict()) + model.ctrl_to_base.load_state_dict(ctrl_downblock.ctrl_to_base.state_dict()) return model @@ -1404,11 +1565,11 @@ def __init__( def from_modules( cls, base_midblock: UNetMidBlock2DCrossAttn, - ctrl_midblock_dict: nn.ModuleDict, + ctrl_midblock: ControlNetXSAddonMidBlockComponents, ): - base_to_ctrl = ctrl_midblock_dict["base_to_ctrl"] - ctrl_to_base = ctrl_midblock_dict["ctrl_to_base"] - ctrl_midblock = ctrl_midblock_dict["midblock"] + base_to_ctrl = ctrl_midblock.base_to_ctrl + ctrl_to_base = ctrl_midblock.ctrl_to_base + ctrl_midblock = ctrl_midblock.midblock # get params def get_first_cross_attention(midblock): @@ -1500,6 +1661,7 @@ def __init__( prev_output_channel: int, ctrl_skip_channels: List[int], temb_channels: int, + resolution_idx: Optional[int] = None, has_crossattn=True, transformer_layers_per_block: int = 1, num_attention_heads: int = 1, @@ -1557,9 +1719,12 @@ def __init__( self.upsamplers = None self.gradient_checkpointing = False + self.resolution_idx = resolution_idx @classmethod - def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_to_base_skip_connections: nn.ModuleList): + def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_upblock: ControlNetXSAddonUpBlockComponents): + ctrl_to_base_skip_connections = ctrl_upblock.ctrl_to_base + # get params def get_first_cross_attention(block): return block.attentions[0].transformer_blocks[0].attn2 @@ -1569,6 +1734,7 @@ def get_first_cross_attention(block): prev_output_channels = base_upblock.resnets[0].in_channels - out_channels ctrl_skip_channelss = [c.in_channels for c in ctrl_to_base_skip_connections] temb_channels = base_upblock.resnets[0].time_emb_proj.in_features + resolution_idx=base_upblock.resolution_idx if hasattr(base_upblock, "attentions"): has_crossattn = True transformer_layers_per_block = len(base_upblock.attentions[0].transformer_blocks) @@ -1590,6 +1756,7 @@ def get_first_cross_attention(block): prev_output_channel=prev_output_channels, ctrl_skip_channels=ctrl_skip_channelss, temb_channels=temb_channels, + resolution_idx=resolution_idx, has_crossattn=has_crossattn, transformer_layers_per_block=transformer_layers_per_block, num_attention_heads=num_attention_heads, @@ -1642,7 +1809,14 @@ def forward( if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - # In ControlNet-XS, the last resnet/attention and the upsampler are treated as a group. + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + # In ControlNet-XS, the last resnet/attention and the upsampler are treated together as one group. # So we separate them to pass information from ctrl to base correctly. if self.upsamplers is None: resnets_without_upsampler = self.resnets @@ -1662,6 +1836,21 @@ def custom_forward(*inputs): return custom_forward + def maybe_apply_freeu_to_subblock(hidden_states, res_h_base): + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + return apply_freeu( + self.resolution_idx, + hidden_states, + res_h_base, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + else: + return hidden_states, res_h_base + for resnet, attn, c2b, res_h_base, res_h_ctrl in zip( resnets_without_upsampler, attn_without_upsampler, @@ -1672,6 +1861,7 @@ def custom_forward(*inputs): if do_control: hidden_states += c2b(res_h_ctrl) * conditioning_scale + hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base) hidden_states = torch.cat([hidden_states, res_h_base], dim=1) if self.training and self.gradient_checkpointing: @@ -1701,6 +1891,7 @@ def custom_forward(*inputs): res_h_ctrl = res_hidden_states_tuple_ctrl[0] if do_control: hidden_states += c2b(res_h_ctrl) * conditioning_scale + hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base) hidden_states = torch.cat([hidden_states, res_h_base], dim=1) hidden_states = resnet_with_upsampler(hidden_states, temb) diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 85ed6f1be173..c6e8085f432e 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -35,8 +35,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor -from ..pipeline_utils import DiffusionPipeline -from ..stable_diffusion import StableDiffusionPipeline +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -90,7 +89,7 @@ class StableDiffusionControlNetXSPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance. @@ -189,49 +188,6 @@ def __init__( ) self.register_to_config(requires_safety_checker=requires_safety_checker) - def get_base_pipeline(self): - """Get underlying `StableDiffusionPipeline` without the `ControlNetXSAddon` model.""" - components = {k: v for k, v in self.components.items() if k != "controlnet"} - components["unet"] = self.components["controlnet"].base_model - return StableDiffusionPipeline(**components) - - def get_controlnet_addon(self): - """Get the `ControlNetXSAddon` model.""" - return self.components["controlnet"].ctrl_addon - - # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_tiling - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.vae.enable_tiling() - - # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_tiling - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index a16c6dac3e64..870e32d2f522 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -49,7 +49,6 @@ ) from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline -from ..stable_diffusion_xl import StableDiffusionXLPipeline from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput @@ -215,49 +214,6 @@ def __init__( self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - def get_base_pipeline(self): - """Get underlying `StableDiffusionXLPipeline` without the `ControlNetXSAddon` model.""" - components = {k: v for k, v in self.components.items() if k != "controlnet"} - components["unet"] = self.components["controlnet"].base_model - return StableDiffusionXLPipeline(**components) - - def get_controlnet_addon(self): - """Get the `ControlNetXSAddon` model.""" - return self.components["controlnet"].ctrl_addon - - # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_tiling - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.vae.enable_tiling() - - # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_tiling - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( self, diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index d34e9fce3c74..eb93ce8cfdf8 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -164,21 +164,21 @@ def assert_equal_weights(module, weight_dict_prefix): # down blocks assert len(controlnet.down_blocks) == len(model.down_blocks) for i, d in enumerate(controlnet.down_blocks): - assert_equal_weights(d["resnets"], f"down_blocks.{i}.ctrl_resnets") - assert_equal_weights(d["base_to_ctrl"], f"down_blocks.{i}.base_to_ctrl") - assert_equal_weights(d["ctrl_to_base"], f"down_blocks.{i}.ctrl_to_base") - if "attentions" in d: - assert_equal_weights(d["attentions"], f"down_blocks.{i}.ctrl_attentions") - if "downsamplers" in d: - assert_equal_weights(d["downsamplers"], f"down_blocks.{i}.ctrl_downsamplers") + assert_equal_weights(d.resnets, f"down_blocks.{i}.ctrl_resnets") + assert_equal_weights(d.base_to_ctrl, f"down_blocks.{i}.base_to_ctrl") + assert_equal_weights(d.ctrl_to_base, f"down_blocks.{i}.ctrl_to_base") + if d.attentions is not None: + assert_equal_weights(d.attentions, f"down_blocks.{i}.ctrl_attentions") + if d.downsamplers is not None: + assert_equal_weights(d.downsamplers, f"down_blocks.{i}.ctrl_downsamplers") # mid block - assert_equal_weights(controlnet.mid_block["base_to_ctrl"], "mid_block.base_to_ctrl") - assert_equal_weights(controlnet.mid_block["midblock"], "mid_block.ctrl_midblock") - assert_equal_weights(controlnet.mid_block["ctrl_to_base"], "mid_block.ctrl_to_base") + assert_equal_weights(controlnet.mid_block.base_to_ctrl, "mid_block.base_to_ctrl") + assert_equal_weights(controlnet.mid_block.midblock, "mid_block.ctrl_midblock") + assert_equal_weights(controlnet.mid_block.ctrl_to_base, "mid_block.ctrl_to_base") # up blocks assert len(controlnet.up_connections) == len(model.up_blocks) for i, u in enumerate(controlnet.up_connections): - assert_equal_weights(u, f"up_blocks.{i}.ctrl_to_base") + assert_equal_weights(u.ctrl_to_base, f"up_blocks.{i}.ctrl_to_base") def test_freeze_unet(self): def assert_frozen(module): diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index aad49029a4bf..e91f9d8d313a 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -62,6 +62,7 @@ PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, + SDFunctionTesterMixin, ) @@ -116,7 +117,11 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout): class ControlNetXSPipelineFastTests( - PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineLatentTesterMixin, + PipelineKarrasSchedulerTesterMixin, + PipelineTesterMixin, + SDFunctionTesterMixin, + unittest.TestCase, ): pipeline_class = StableDiffusionControlNetXSPipeline params = TEXT_TO_IMAGE_PARAMS diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 13007a2aa1f7..d5a500f1c692 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -30,6 +30,7 @@ ) from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import IPAdapterMixin +from diffusers.models.controlnet_xs import UNetControlNetXSModel from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet from diffusers.models.unets.unet_motion_model import UNetMotionModel @@ -1327,7 +1328,7 @@ def test_StableDiffusionMixin_component(self): self.assertTrue(hasattr(pipe, "vae") and isinstance(pipe.vae, (AutoencoderKL, AutoencoderTiny))) self.assertTrue( hasattr(pipe, "unet") - and isinstance(pipe.unet, (UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel)) + and isinstance(pipe.unet, (UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel, UNetControlNetXSModel)) ) From 18ded9d1b68a398fe4ca42af990dfe884506832a Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 3 Apr 2024 20:21:09 +0200 Subject: [PATCH 66/75] Make style, quality, fix-copies --- Pipfile | 11 +++ src/diffusers/models/controlnet_xs.py | 94 ++++++++++++------- .../controlnet_xs/pipeline_controlnet_xs.py | 4 +- .../pipeline_controlnet_xs_sd_xl.py | 4 +- tests/pipelines/test_pipelines_common.py | 5 +- 5 files changed, 77 insertions(+), 41 deletions(-) create mode 100644 Pipfile diff --git a/Pipfile b/Pipfile new file mode 100644 index 000000000000..0757494bb360 --- /dev/null +++ b/Pipfile @@ -0,0 +1,11 @@ +[[source]] +url = "https://pypi.org/simple" +verify_ssl = true +name = "pypi" + +[packages] + +[dev-packages] + +[requires] +python_version = "3.11" diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 9f277e286ffb..d16bc42e663c 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -49,16 +49,25 @@ class ControlNetXSOutput(BaseOutput): Args: sample (`FloatTensor` of shape `(batch_size, num_channels, height, width)`): - The output of the `UNetControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model - output, but is already the final output. + The output of the `UNetControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base + model output, but is already the final output. """ sample: FloatTensor = None class ControlNetXSAddonDownBlockComponents(nn.Module): - """Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnDownBlock2D`""" - def __init__(self, resnets: nn.ModuleList, base_to_ctrl:nn.ModuleList, ctrl_to_base:nn.ModuleList, attentions: Optional[nn.ModuleList] = None,downsampler: Optional[nn.Conv2d] = None): + """Components that together with corresponding components from the base model will form a + `ControlNetXSCrossAttnDownBlock2D`""" + + def __init__( + self, + resnets: nn.ModuleList, + base_to_ctrl: nn.ModuleList, + ctrl_to_base: nn.ModuleList, + attentions: Optional[nn.ModuleList] = None, + downsampler: Optional[nn.Conv2d] = None, + ): super().__init__() self.resnets = resnets self.base_to_ctrl = base_to_ctrl @@ -68,8 +77,10 @@ def __init__(self, resnets: nn.ModuleList, base_to_ctrl:nn.ModuleList, ctrl_to_b class ControlNetXSAddonMidBlockComponents(nn.Module): - """Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnMidBlock2D`""" - def __init__(self, midblock: UNetMidBlock2DCrossAttn, base_to_ctrl:nn.ModuleList, ctrl_to_base:nn.ModuleList): + """Components that together with corresponding components from the base model will form a + `ControlNetXSCrossAttnMidBlock2D`""" + + def __init__(self, midblock: UNetMidBlock2DCrossAttn, base_to_ctrl: nn.ModuleList, ctrl_to_base: nn.ModuleList): super().__init__() self.midblock = midblock self.base_to_ctrl = base_to_ctrl @@ -78,7 +89,8 @@ def __init__(self, midblock: UNetMidBlock2DCrossAttn, base_to_ctrl:nn.ModuleList class ControlNetXSAddonUpBlockComponents(nn.Module): """Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnUpBlock2D`""" - def __init__(self, ctrl_to_base:nn.ModuleList): + + def __init__(self, ctrl_to_base: nn.ModuleList): super().__init__() self.ctrl_to_base = ctrl_to_base @@ -131,13 +143,14 @@ def forward(self, conditioning): class ControlNetXSAddon(ModelMixin, ConfigMixin): r""" - A `ControlNetXSAddon` model. To use it, pass it into a `ControlNetXSModel` (together with a `UNet2DConditionModel` base model). + A `ControlNetXSAddon` model. To use it, pass it into a `ControlNetXSModel` (together with a `UNet2DConditionModel` + base model). This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). - Like `ControlNetXSModel`, `ControlNetXSAddon` is compatible with StableDiffusion and StableDiffusion-XL. - It's default parameters are compatible with StableDiffusion. + Like `ControlNetXSModel`, `ControlNetXSAddon` is compatible with StableDiffusion and StableDiffusion-XL. It's + default parameters are compatible with StableDiffusion. Parameters: conditioning_channels (`int`, defaults to 3): @@ -147,12 +160,11 @@ class ControlNetXSAddon(ModelMixin, ConfigMixin): conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): The tuple of output channels for each block in the `controlnet_cond_embedding` layer. time_embedding_mix (`float`, defaults to 1.0): - If 0, then only the control addon's time embedding is used. - If 1, then only the base unet's time embedding is used. - Otherwise, both are combined. + If 0, then only the control addon's time embedding is used. If 1, then only the base unet's time embedding + is used. Otherwise, both are combined. learn_time_embedding (`bool`, defaults to `False`): - Whether a time embedding should be learned. If yes, `ControlNetXSModel` will combine the time embeddings of the base model and the addon. - If no, `ControlNetXSModel` will use the base model's time embedding. + Whether a time embedding should be learned. If yes, `ControlNetXSModel` will combine the time embeddings of + the base model and the addon. If no, `ControlNetXSModel` will use the base model's time embedding. num_attention_heads (`list[int]`, defaults to `[4]`): The number of attention heads. block_out_channels (`list[int]`, defaults to `[4, 8, 16, 16]`): @@ -171,7 +183,8 @@ class ControlNetXSAddon(ModelMixin, ConfigMixin): upcast_attention (`bool`, defaults to `True`): Whether the attention computation should always be upcasted. max_norm_num_groups (`int`, defaults to 32): - Maximum number of groups in group normal. The actual number will the the largest divisor of the respective channels, that is <= max_norm_num_groups. + Maximum number of groups in group normal. The actual number will the the largest divisor of the respective + channels, that is <= max_norm_num_groups. """ @register_to_config @@ -391,7 +404,7 @@ def get_down_block( down_block_components = ControlNetXSAddonDownBlockComponents( resnets=nn.ModuleList(resnets), base_to_ctrl=nn.ModuleList(base_to_ctrl), - ctrl_to_base=nn.ModuleList(ctrl_to_base) + ctrl_to_base=nn.ModuleList(ctrl_to_base), ) if has_crossattn: @@ -433,7 +446,9 @@ def get_mid_block( # Addition requires change in number of channels ctrl_to_base = make_zero_conv(ctrl_channels, base_channels) - return ControlNetXSAddonMidBlockComponents(base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base) + return ControlNetXSAddonMidBlockComponents( + base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base + ) @staticmethod def get_up_connections( @@ -469,18 +484,18 @@ def from_unet( unet (`UNet2DConditionModel`): The UNet model we want to control. The dimensions of the ControlNetXSAddon will be adapted to it. size_ratio (float, *optional*, defaults to `None`): - When given, block_out_channels is set to a fraction of the base model's block_out_channels. - Either this or `block_out_channels` must be given. + When given, block_out_channels is set to a fraction of the base model's block_out_channels. Either this + or `block_out_channels` must be given. block_out_channels (`List[int]`, *optional*, defaults to `None`): Down blocks output channels in control model. Either this or `size_ratio` must be given. num_attention_heads (`List[int]`, *optional*, defaults to `None`): - The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. + The dimension of the attention heads. The naming seems a bit confusing and it is, see + https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. learn_time_embedding (`bool`, defaults to `False`): Whether the `ControlNetXSAddon` should learn a time embedding. time_embedding_mix (`float`, defaults to 1.0): - If 0, then only the control addon's time embedding is used. - If 1, then only the base unet's time embedding is used. - Otherwise, both are combined. + If 0, then only the control addon's time embedding is used. If 1, then only the base unet's time + embedding is used. Otherwise, both are combined. conditioning_channels (`int`, defaults to 3): Number of channels of conditioning input (e.g. an image) conditioning_channel_order (`str`, defaults to `"rgb"`): @@ -538,10 +553,11 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). - `UNetControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. - It's default parameters are compatible with StableDiffusion. + `UNetControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. It's default parameters are + compatible with StableDiffusion. - It's parameters are either passed to the underlying `UNet2DConditionModel` or used exactly like in `ControlNetXSAddon` . See their documentation for details. + It's parameters are either passed to the underlying `UNet2DConditionModel` or used exactly like in + `ControlNetXSAddon` . See their documentation for details. """ _supports_gradient_checkpointing = True @@ -739,11 +755,13 @@ def from_unet( unet (`UNet2DConditionModel`): The UNet model we want to control. controlnet (`ControlNetXSAddon`): - The ConntrolNet-XS addon with which the UNet will be fused. If none is given, a new ConntrolNet-XS addon will be created. + The ConntrolNet-XS addon with which the UNet will be fused. If none is given, a new ConntrolNet-XS + addon will be created. size_ratio (float, *optional*, defaults to `None`): Used to contruct the controlnet if none is given. See [`ControlNetXSAddon.from_unet`] for details. ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`): - Used to contruct the controlnet if none is given. See [`ControlNetXSAddon.from_unet`] for details, where this parameter is called `block_out_channels`. + Used to contruct the controlnet if none is given. See [`ControlNetXSAddon.from_unet`] for details, + where this parameter is called `block_out_channels`. time_embedding_mix (`float`, *optional*, defaults to None): Used to contruct the controlnet if none is given. See [`ControlNetXSAddon.from_unet`] for details. ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`): @@ -836,7 +854,8 @@ def from_unet( return model def freeze_unet_params(self) -> None: - """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine tuning.""" + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine + tuning.""" # Freeze everything for param in self.parameters(): param.requires_grad = True @@ -971,8 +990,8 @@ def disable_freeu(self): # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel def fuse_qkv_projections(self): """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, - key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. @@ -1384,7 +1403,8 @@ def get_first_cross_attention(block): return model def freeze_base_params(self) -> None: - """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine tuning.""" + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine + tuning.""" # Unfreeze everything for param in self.parameters(): param.requires_grad = True @@ -1607,7 +1627,8 @@ def get_first_cross_attention(midblock): return model def freeze_base_params(self) -> None: - """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine tuning.""" + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine + tuning.""" # Unfreeze everything for param in self.parameters(): param.requires_grad = True @@ -1734,7 +1755,7 @@ def get_first_cross_attention(block): prev_output_channels = base_upblock.resnets[0].in_channels - out_channels ctrl_skip_channelss = [c.in_channels for c in ctrl_to_base_skip_connections] temb_channels = base_upblock.resnets[0].time_emb_proj.in_features - resolution_idx=base_upblock.resolution_idx + resolution_idx = base_upblock.resolution_idx if hasattr(base_upblock, "attentions"): has_crossattn = True transformer_layers_per_block = len(base_upblock.attentions[0].transformer_blocks) @@ -1776,7 +1797,8 @@ def get_first_cross_attention(block): return model def freeze_base_params(self) -> None: - """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine tuning.""" + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine + tuning.""" # Unfreeze everything for param in self.parameters(): param.requires_grad = True diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index c6e8085f432e..dddd1479cd24 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -60,7 +60,7 @@ >>> # download an image >>> image = load_image( - ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" ... ) >>> # initialize the models and pipeline @@ -667,7 +667,7 @@ def __call__( Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index 870e32d2f522..df81a0073423 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -86,7 +86,7 @@ ... "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16 ... ) >>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, , torch_dtype=torch.float16 + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 ... ) >>> pipe.enable_model_cpu_offload() @@ -765,7 +765,7 @@ def __call__( prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders. - image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 0692792e9636..6ec73626f9fb 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1634,7 +1634,10 @@ def test_StableDiffusionMixin_component(self): self.assertTrue(hasattr(pipe, "vae") and isinstance(pipe.vae, (AutoencoderKL, AutoencoderTiny))) self.assertTrue( hasattr(pipe, "unet") - and isinstance(pipe.unet, (UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel, UNetControlNetXSModel)) + and isinstance( + pipe.unet, + (UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel, UNetControlNetXSModel), + ) ) From 25b747322a1f0ae1048d3b95453fd9bc3feeaeea Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 3 Apr 2024 22:09:51 +0200 Subject: [PATCH 67/75] Implemented review feedback --- src/diffusers/models/controlnet_xs.py | 47 ++++++++-------- .../controlnet_xs/pipeline_controlnet_xs.py | 53 ++++--------------- 2 files changed, 33 insertions(+), 67 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index d16bc42e663c..a4a4da26c699 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -212,8 +212,6 @@ def __init__( ): super().__init__() - self.sample_size = sample_size - time_embedding_input_dim = base_block_out_channels[0] time_embedding_dim = base_block_out_channels[0] * 4 @@ -226,12 +224,13 @@ def __init__( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) - transformer_layers_per_block = repeat_if_not_list( - transformer_layers_per_block, repetitions=len(down_block_types) - ) - cross_attention_dim = repeat_if_not_list(cross_attention_dim, repetitions=len(down_block_types)) + if not isinstance(transformer_layers_per_block, (list, tuple)): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + if not isinstance(cross_attention_dim, (list, tuple)): + cross_attention_dim = [cross_attention_dim] * len(down_block_types) # see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why `ControlNetXSAddon` takes `num_attention_heads` instead of `attention_head_dim` - num_attention_heads = repeat_if_not_list(num_attention_heads, repetitions=len(down_block_types)) + if not isinstance(num_attention_heads, (list, tuple)): + num_attention_heads = [num_attention_heads] * len(down_block_types) if len(num_attention_heads) != len(down_block_types): raise ValueError( @@ -251,8 +250,6 @@ def __init__( else: self.time_embedding = None - self.time_embed_act = None - self.down_blocks = nn.ModuleList([]) self.up_connections = nn.ModuleList([]) @@ -272,7 +269,7 @@ def __init__( is_final_block = i == len(down_block_types) - 1 self.down_blocks.append( - self.get_down_block( + self.get_down_block_addon( base_in_channels=base_in_channels, base_out_channels=base_out_channels, ctrl_in_channels=ctrl_in_channels, @@ -289,7 +286,7 @@ def __init__( ) # mid - self.mid_block = self.get_mid_block( + self.mid_block = self.get_mid_block_addon( base_channels=base_block_out_channels[-1], ctrl_channels=block_out_channels[-1], temb_channels=time_embedding_dim, @@ -317,7 +314,7 @@ def __init__( ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)] self.up_connections.append( - self.get_up_connections( + self.get_up_block_addon( out_channels=base_out_channels, prev_output_channel=prev_base_output_channel, ctrl_skip_channels=ctrl_skip_channels_, @@ -325,7 +322,7 @@ def __init__( ) @staticmethod - def get_down_block( + def get_down_block_addon( base_in_channels: int, base_out_channels: int, ctrl_in_channels: int, @@ -415,7 +412,7 @@ def get_down_block( return down_block_components @staticmethod - def get_mid_block( + def get_mid_block_addon( base_channels: int, ctrl_channels: int, temb_channels: Optional[int] = None, @@ -451,7 +448,7 @@ def get_mid_block( ) @staticmethod - def get_up_connections( + def get_up_block_addon( out_channels: int, prev_output_channel: int, ctrl_skip_channels: List[int], @@ -608,12 +605,16 @@ def __init__( "As `UNetControlNetXSModel` currently only supports StableDiffusion and StableDiffusion-XL, `addition_embed_type` must be `None` or `'text_time'`." ) - transformer_layers_per_block = repeat_if_not_list( - transformer_layers_per_block, repetitions=len(down_block_types) - ) - cross_attention_dim = repeat_if_not_list(cross_attention_dim, repetitions=len(down_block_types)) - base_num_attention_heads = repeat_if_not_list(num_attention_heads, repetitions=len(down_block_types)) - ctrl_num_attention_heads = repeat_if_not_list(ctrl_num_attention_heads, repetitions=len(down_block_types)) + if not isinstance(transformer_layers_per_block, (list, tuple)): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + if not isinstance(cross_attention_dim, (list, tuple)): + cross_attention_dim = [cross_attention_dim] * len(down_block_types) + if not isinstance(num_attention_heads, (list, tuple)): + num_attention_heads = [num_attention_heads] * len(down_block_types) + if not isinstance(ctrl_num_attention_heads, (list, tuple)): + ctrl_num_attention_heads = [ctrl_num_attention_heads] * len(down_block_types) + + base_num_attention_heads = num_attention_heads self.in_channels = 4 @@ -1950,7 +1951,3 @@ def find_largest_factor(number, max_factor): if residual == 0: return factor factor -= 1 - - -def repeat_if_not_list(value, repetitions): - return value if isinstance(value, (tuple, list)) else [value] * repetitions diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index dddd1479cd24..3f52a6a679c4 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -452,7 +452,6 @@ def check_inputs( self, prompt, image, - callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, @@ -461,11 +460,6 @@ def check_inputs( control_guidance_end=1.0, callback_on_step_end_tensor_inputs=None, ): - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -659,7 +653,6 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - **kwargs, ): r""" The call function to the pipeline for generation. @@ -744,29 +737,12 @@ def __call__( "not-safe-for-work" (nsfw) content. """ - callback = kwargs.pop("callback", None) - callback_steps = kwargs.pop("callback_steps", None) - - if callback is not None: - deprecate( - "callback", - "1.0.0", - "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", - ) - if callback_steps is not None: - deprecate( - "callback_steps", - "1.0.0", - "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", - ) - unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, image, - callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, @@ -818,20 +794,17 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare image - if isinstance(unet, UNetControlNetXSModel): - image = self.prepare_image( - image=image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=unet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - ) - height, width = image.shape[-2:] - else: - assert False + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=unet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + height, width = image.shape[-2:] # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -900,12 +873,8 @@ def __call__( prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings From 6ba84ca8828920b2ac16cfdaa1941b1805997892 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 3 Apr 2024 22:23:39 +0200 Subject: [PATCH 68/75] Removed compatibility check for vae/ctrl embedding --- src/diffusers/models/controlnet_xs.py | 7 ------- .../pipelines/controlnet_xs/pipeline_controlnet_xs.py | 10 ---------- .../controlnet_xs/pipeline_controlnet_xs_sd_xl.py | 10 ---------- 3 files changed, 27 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index a4a4da26c699..716236e13a56 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -883,13 +883,6 @@ def freeze_unet_params(self) -> None: for u in self.up_blocks: u.freeze_base_params() - @torch.no_grad() - def _check_if_vae_compatible(self, vae: AutoencoderKL): - condition_downscale_factor = 2 ** (len(self.config.ctrl_conditioning_embedding_out_channels) - 1) - vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1) - compatible = condition_downscale_factor == vae_downscale_factor - return compatible, condition_downscale_factor, vae_downscale_factor - def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 3f52a6a679c4..5d882d917df5 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -161,16 +161,6 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - ( - vae_compatible, - cnxs_condition_downsample_factor, - vae_downsample_factor, - ) = unet._check_if_vae_compatible(vae) - if not vae_compatible: - raise ValueError( - f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXSAddon model ({cnxs_condition_downsample_factor}) need to be equal. Consider building the ControlNetXSAddon model with different `conditioning_embedding_out_channels`." - ) - self.register_modules( vae=vae, text_encoder=text_encoder, diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index df81a0073423..644335091072 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -179,16 +179,6 @@ def __init__( if isinstance(unet, UNet2DConditionModel): unet = UNetControlNetXSModel.from_unet(unet, controlnet) - ( - vae_compatible, - cnxs_condition_downsample_factor, - vae_downsample_factor, - ) = unet._check_if_vae_compatible(vae) - if not vae_compatible: - raise ValueError( - f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXSAddon model ({cnxs_condition_downsample_factor}) need to be equal. Consider building the ControlNetXSAddon model with different `conditioning_embedding_out_channels`." - ) - self.register_modules( vae=vae, text_encoder=text_encoder, From 4ded3adedf279de56af0661e76db559c9cb004de Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 3 Apr 2024 22:24:47 +0200 Subject: [PATCH 69/75] make style, quality, fix-copies --- src/diffusers/models/controlnet_xs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 716236e13a56..a32f5ba12cb5 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -24,7 +24,6 @@ from ..utils import BaseOutput, is_torch_version, logging from ..utils.torch_utils import apply_freeu from .attention_processor import Attention, AttentionProcessor -from .autoencoders import AutoencoderKL from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unets.unet_2d_blocks import ( From c4181602dbd6be7a9bb9019da6e7c766bf8df74e Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Wed, 3 Apr 2024 22:25:26 +0200 Subject: [PATCH 70/75] Delete Pipfile --- Pipfile | 11 ----------- 1 file changed, 11 deletions(-) delete mode 100644 Pipfile diff --git a/Pipfile b/Pipfile deleted file mode 100644 index 0757494bb360..000000000000 --- a/Pipfile +++ /dev/null @@ -1,11 +0,0 @@ -[[source]] -url = "https://pypi.org/simple" -verify_ssl = true -name = "pypi" - -[packages] - -[dev-packages] - -[requires] -python_version = "3.11" From 0de4b44be1c585b97d3e4c84b9403757595478f0 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Mon, 8 Apr 2024 10:20:51 +0200 Subject: [PATCH 71/75] Integrated review feedback - Importing ControlNetConditioningEmbedding now - get_down/mid/up_block_addon now outside class - renamed `do_control` to `apply_control` --- src/diffusers/models/controlnet_xs.py | 357 ++++++++---------- .../controlnet_xs/pipeline_controlnet_xs.py | 4 +- .../pipeline_controlnet_xs_sd_xl.py | 4 +- .../unets/test_models_unet_controlnetxs.py | 2 +- 4 files changed, 161 insertions(+), 206 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index a32f5ba12cb5..d627bb3a0d3a 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -24,6 +24,7 @@ from ..utils import BaseOutput, is_torch_version, logging from ..utils.torch_utils import apply_freeu from .attention_processor import Attention, AttentionProcessor +from .controlnet import ControlNetConditioningEmbedding from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unets.unet_2d_blocks import ( @@ -94,50 +95,144 @@ def __init__(self, ctrl_to_base: nn.ModuleList): self.ctrl_to_base = ctrl_to_base -# copied from diffusers.models.controlnet.ControlNetConditioningEmbedding -class ControlNetConditioningEmbedding(nn.Module): - """ - Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN - [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized - training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the - convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides - (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full - model) to encode image-space conditions ... into feature maps ..." - """ - - def __init__( - self, - conditioning_embedding_channels: int, - conditioning_channels: int = 3, - block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), - ): - super().__init__() +def get_down_block_addon( + base_in_channels: int, + base_out_channels: int, + ctrl_in_channels: int, + ctrl_out_channels: int, + temb_channels: int, + max_norm_num_groups: Optional[int] = 32, + has_crossattn=True, + transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, + num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + add_downsample: bool = True, + upcast_attention: Optional[bool] = False, +): + num_layers = 2 # only support sd + sdxl + + resnets = [] + attentions = [] + ctrl_to_base = [] + base_to_ctrl = [] + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + base_in_channels = base_in_channels if i == 0 else base_out_channels + ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels + + # Before the resnet/attention application, information is concatted from base to control. + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels)) + + resnets.append( + ResnetBlock2D( + in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl + out_channels=ctrl_out_channels, + temb_channels=temb_channels, + groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups), + groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), + eps=1e-5, + ) + ) - self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + if has_crossattn: + attentions.append( + Transformer2DModel( + num_attention_heads, + ctrl_out_channels // num_attention_heads, + in_channels=ctrl_out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=True, + upcast_attention=upcast_attention, + norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), + ) + ) - self.blocks = nn.ModuleList([]) + # After the resnet/attention application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) - for i in range(len(block_out_channels) - 1): - channel_in = block_out_channels[i] - channel_out = block_out_channels[i + 1] - self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) - self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + if add_downsample: + # Before the downsampler application, information is concatted from base to control + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) - self.conv_out = zero_module( - nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + downsamplers = Downsample2D( + ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op" ) - def forward(self, conditioning): - embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) - - for block in self.blocks: - embedding = block(embedding) - embedding = F.silu(embedding) - - embedding = self.conv_out(embedding) - - return embedding + # After the downsampler application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) + else: + downsamplers = None + + down_block_components = ControlNetXSAddonDownBlockComponents( + resnets=nn.ModuleList(resnets), + base_to_ctrl=nn.ModuleList(base_to_ctrl), + ctrl_to_base=nn.ModuleList(ctrl_to_base), + ) + + if has_crossattn: + down_block_components.attentions = nn.ModuleList(attentions) + if downsamplers is not None: + down_block_components.downsamplers = downsamplers + + return down_block_components + + +def get_mid_block_addon( + base_channels: int, + ctrl_channels: int, + temb_channels: Optional[int] = None, + max_norm_num_groups: Optional[int] = 32, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + upcast_attention: bool = False, +): + # Before the midblock application, information is concatted from base to control. + # Concat doesn't require change in number of channels + base_to_ctrl = make_zero_conv(base_channels, base_channels) + + midblock = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=ctrl_channels + base_channels, + out_channels=ctrl_channels, + temb_channels=temb_channels, + # number or norm groups must divide both in_channels and out_channels + resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups), + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + use_linear_projection=True, + upcast_attention=upcast_attention, + ) + + # After the midblock application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base = make_zero_conv(ctrl_channels, base_channels) + + return ControlNetXSAddonMidBlockComponents( + base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base + ) + + +def get_up_block_addon( + out_channels: int, + prev_output_channel: int, + ctrl_skip_channels: List[int], +): + ctrl_to_base = [] + num_layers = 3 # only support sd + sdxl + for i in range(num_layers): + resnet_in_channels = prev_output_channel if i == 0 else out_channels + ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels)) + + return ControlNetXSAddonUpBlockComponents(ctrl_to_base=nn.ModuleList(ctrl_to_base)) class ControlNetXSAddon(ModelMixin, ConfigMixin): @@ -268,7 +363,7 @@ def __init__( is_final_block = i == len(down_block_types) - 1 self.down_blocks.append( - self.get_down_block_addon( + get_down_block_addon( base_in_channels=base_in_channels, base_out_channels=base_out_channels, ctrl_in_channels=ctrl_in_channels, @@ -285,7 +380,7 @@ def __init__( ) # mid - self.mid_block = self.get_mid_block_addon( + self.mid_block = get_mid_block_addon( base_channels=base_block_out_channels[-1], ctrl_channels=block_out_channels[-1], temb_channels=time_embedding_dim, @@ -313,153 +408,13 @@ def __init__( ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)] self.up_connections.append( - self.get_up_block_addon( + get_up_block_addon( out_channels=base_out_channels, prev_output_channel=prev_base_output_channel, ctrl_skip_channels=ctrl_skip_channels_, ) ) - @staticmethod - def get_down_block_addon( - base_in_channels: int, - base_out_channels: int, - ctrl_in_channels: int, - ctrl_out_channels: int, - temb_channels: int, - max_norm_num_groups: Optional[int] = 32, - has_crossattn=True, - transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, - num_attention_heads: Optional[int] = 1, - cross_attention_dim: Optional[int] = 1024, - add_downsample: bool = True, - upcast_attention: Optional[bool] = False, - ): - num_layers = 2 # only support sd + sdxl - - resnets = [] - attentions = [] - ctrl_to_base = [] - base_to_ctrl = [] - - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * num_layers - - for i in range(num_layers): - base_in_channels = base_in_channels if i == 0 else base_out_channels - ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels - - # Before the resnet/attention application, information is concatted from base to control. - # Concat doesn't require change in number of channels - base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels)) - - resnets.append( - ResnetBlock2D( - in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl - out_channels=ctrl_out_channels, - temb_channels=temb_channels, - groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups), - groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), - eps=1e-5, - ) - ) - - if has_crossattn: - attentions.append( - Transformer2DModel( - num_attention_heads, - ctrl_out_channels // num_attention_heads, - in_channels=ctrl_out_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - use_linear_projection=True, - upcast_attention=upcast_attention, - norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), - ) - ) - - # After the resnet/attention application, information is added from control to base - # Addition requires change in number of channels - ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) - - if add_downsample: - # Before the downsampler application, information is concatted from base to control - # Concat doesn't require change in number of channels - base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) - - downsamplers = Downsample2D( - ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op" - ) - - # After the downsampler application, information is added from control to base - # Addition requires change in number of channels - ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) - else: - downsamplers = None - - down_block_components = ControlNetXSAddonDownBlockComponents( - resnets=nn.ModuleList(resnets), - base_to_ctrl=nn.ModuleList(base_to_ctrl), - ctrl_to_base=nn.ModuleList(ctrl_to_base), - ) - - if has_crossattn: - down_block_components.attentions = nn.ModuleList(attentions) - if downsamplers is not None: - down_block_components.downsamplers = downsamplers - - return down_block_components - - @staticmethod - def get_mid_block_addon( - base_channels: int, - ctrl_channels: int, - temb_channels: Optional[int] = None, - max_norm_num_groups: Optional[int] = 32, - transformer_layers_per_block: int = 1, - num_attention_heads: Optional[int] = 1, - cross_attention_dim: Optional[int] = 1024, - upcast_attention: bool = False, - ): - # Before the midblock application, information is concatted from base to control. - # Concat doesn't require change in number of channels - base_to_ctrl = make_zero_conv(base_channels, base_channels) - - midblock = UNetMidBlock2DCrossAttn( - transformer_layers_per_block=transformer_layers_per_block, - in_channels=ctrl_channels + base_channels, - out_channels=ctrl_channels, - temb_channels=temb_channels, - # number or norm groups must divide both in_channels and out_channels - resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups), - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - use_linear_projection=True, - upcast_attention=upcast_attention, - ) - - # After the midblock application, information is added from control to base - # Addition requires change in number of channels - ctrl_to_base = make_zero_conv(ctrl_channels, base_channels) - - return ControlNetXSAddonMidBlockComponents( - base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base - ) - - @staticmethod - def get_up_block_addon( - out_channels: int, - prev_output_channel: int, - ctrl_skip_channels: List[int], - ): - ctrl_to_base = [] - num_layers = 3 # only support sd + sdxl - for i in range(num_layers): - resnet_in_channels = prev_output_channel if i == 0 else out_channels - ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels)) - - return ControlNetXSAddonUpBlockComponents(ctrl_to_base=nn.ModuleList(ctrl_to_base)) - @classmethod def from_unet( cls, @@ -1031,7 +986,7 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, return_dict: bool = True, - do_control: bool = True, + apply_control: bool = True, ) -> Union[ControlNetXSOutput, Tuple]: """ The [`ControlNetXSModel`] forward method. @@ -1063,7 +1018,7 @@ def forward( Additional conditions for the Stable Diffusion XL UNet. return_dict (`bool`, defaults to `True`): Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. - do_control (`bool`, defaults to `True`): + apply_control (`bool`, defaults to `True`): If `False`, the input is run only through the base model. Returns: @@ -1105,7 +1060,7 @@ def forward( # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) - if self.config.ctrl_learn_time_embedding and do_control: + if self.config.ctrl_learn_time_embedding and apply_control: ctrl_temb = self.ctrl_time_embedding(t_emb, timestep_cond) base_temb = self.base_time_embedding(t_emb, timestep_cond) interpolation_param = self.config.time_embedding_mix**0.3 @@ -1159,7 +1114,7 @@ def forward( h_ctrl = self.ctrl_conv_in(h_ctrl) if guided_hint is not None: h_ctrl += guided_hint - if do_control: + if apply_control: h_base = h_base + self.control_to_base_for_conv_in(h_ctrl) * conditioning_scale # add ctrl -> base hs_base.append(h_base) @@ -1174,7 +1129,7 @@ def forward( conditioning_scale=conditioning_scale, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, - do_control=do_control, + apply_control=apply_control, ) hs_base.extend(residual_hb) hs_ctrl.extend(residual_hc) @@ -1188,7 +1143,7 @@ def forward( conditioning_scale=conditioning_scale, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, - do_control=do_control, + apply_control=apply_control, ) # 3 - up @@ -1207,7 +1162,7 @@ def forward( conditioning_scale=conditioning_scale, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, - do_control=do_control, + apply_control=apply_control, ) # 4 - conv out @@ -1422,7 +1377,7 @@ def forward( attention_mask: Optional[FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[FloatTensor] = None, - do_control: bool = True, + apply_control: bool = True, ) -> Tuple[FloatTensor, FloatTensor, Tuple[FloatTensor, ...], Tuple[FloatTensor, ...]]: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: @@ -1450,7 +1405,7 @@ def custom_forward(*inputs): base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base ): # concat base -> ctrl - if do_control: + if apply_control: h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) # apply base subblock @@ -1476,7 +1431,7 @@ def custom_forward(*inputs): )[0] # apply ctrl subblock - if do_control: + if apply_control: if self.training and self.gradient_checkpointing: ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} h_ctrl = torch.utils.checkpoint.checkpoint( @@ -1498,7 +1453,7 @@ def custom_forward(*inputs): )[0] # add ctrl -> base - if do_control: + if apply_control: h_base = h_base + c2b(h_ctrl) * conditioning_scale base_output_states = base_output_states + (h_base,) @@ -1509,15 +1464,15 @@ def custom_forward(*inputs): c2b = self.ctrl_to_base[-1] # concat base -> ctrl - if do_control: + if apply_control: h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) # apply base subblock h_base = self.base_downsamplers(h_base) # apply ctrl subblock - if do_control: + if apply_control: h_ctrl = self.ctrl_downsamplers(h_ctrl) # add ctrl -> base - if do_control: + if apply_control: h_base = h_base + c2b(h_ctrl) * conditioning_scale base_output_states = base_output_states + (h_base,) @@ -1640,7 +1595,7 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, attention_mask: Optional[FloatTensor] = None, encoder_attention_mask: Optional[FloatTensor] = None, - do_control: bool = True, + apply_control: bool = True, ) -> Tuple[FloatTensor, FloatTensor]: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: @@ -1657,10 +1612,10 @@ def forward( "encoder_attention_mask": encoder_attention_mask, } - if do_control: + if apply_control: h_ctrl = torch.cat([h_ctrl, self.base_to_ctrl(h_base)], dim=1) # concat base -> ctrl h_base = self.base_midblock(h_base, **joint_args) # apply base mid block - if do_control: + if apply_control: h_ctrl = self.ctrl_midblock(h_ctrl, **joint_args) # apply ctrl mid block h_base = h_base + self.ctrl_to_base(h_ctrl) * conditioning_scale # add ctrl -> base @@ -1818,7 +1773,7 @@ def forward( attention_mask: Optional[FloatTensor] = None, upsample_size: Optional[int] = None, encoder_attention_mask: Optional[FloatTensor] = None, - do_control: bool = True, + apply_control: bool = True, ) -> FloatTensor: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: @@ -1873,7 +1828,7 @@ def maybe_apply_freeu_to_subblock(hidden_states, res_h_base): reversed(res_hidden_states_tuple_base), reversed(res_hidden_states_tuple_ctrl), ): - if do_control: + if apply_control: hidden_states += c2b(res_h_ctrl) * conditioning_scale hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base) @@ -1904,7 +1859,7 @@ def maybe_apply_freeu_to_subblock(hidden_states, res_h_base): c2b = self.ctrl_to_base[-1] res_h_base = res_hidden_states_tuple_base[0] res_h_ctrl = res_hidden_states_tuple_ctrl[0] - if do_control: + if apply_control: hidden_states += c2b(res_h_ctrl) * conditioning_scale hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base) hidden_states = torch.cat([hidden_states, res_h_base], dim=1) diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 5d882d917df5..8ab5b74686b8 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -832,7 +832,7 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - do_control = ( + apply_control = ( i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end ) noise_pred = self.unet( @@ -843,7 +843,7 @@ def __call__( conditioning_scale=controlnet_conditioning_scale, cross_attention_kwargs=cross_attention_kwargs, return_dict=True, - do_control=do_control, + apply_control=apply_control, ).sample # perform guidance diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index 644335091072..697b5a17364d 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -1048,7 +1048,7 @@ def __call__( added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} # predict the noise residual - do_control = ( + apply_control = ( i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end ) noise_pred = self.unet( @@ -1060,7 +1060,7 @@ def __call__( cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=True, - do_control=do_control, + apply_control=apply_control, ).sample # perform guidance diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index eb93ce8cfdf8..09c134533209 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -312,7 +312,7 @@ def test_forward_no_control(self): with torch.no_grad(): unet_output = unet(**input_for_unet).sample.cpu() - unet_controlnet_output = model(**input_, do_control=False).sample.cpu() + unet_controlnet_output = model(**input_, apply_control=False).sample.cpu() assert np.abs(unet_output.flatten() - unet_controlnet_output.flatten()).max() < 3e-4 From 1d8cad8eec04e3da01d86e0e911d240f9ec975aa Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Mon, 8 Apr 2024 19:16:35 +0200 Subject: [PATCH 72/75] Reduced size of test tensors For this, added `norm_num_groups` as parameter everywhere --- src/diffusers/models/controlnet_xs.py | 49 ++++++++++++------ .../unets/test_models_unet_controlnetxs.py | 51 +++++++++++-------- .../controlnet_xs/test_controlnetxs.py | 22 ++++---- .../controlnet_xs/test_controlnetxs_sdxl.py | 20 ++++---- 4 files changed, 86 insertions(+), 56 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index d627bb3a0d3a..165c8e8273e7 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -18,7 +18,6 @@ import torch import torch.utils.checkpoint from torch import FloatTensor, nn -from torch.nn import functional as F from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, is_torch_version, logging @@ -216,9 +215,7 @@ def get_mid_block_addon( # Addition requires change in number of channels ctrl_to_base = make_zero_conv(ctrl_channels, base_channels) - return ControlNetXSAddonMidBlockComponents( - base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base - ) + return ControlNetXSAddonMidBlockComponents(base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base) def get_up_block_addon( @@ -620,7 +617,8 @@ def __init__( ctrl_in_channels=ctrl_in_channels, ctrl_out_channels=ctrl_out_channels, temb_channels=time_embed_dim, - max_norm_num_groups=ctrl_max_norm_num_groups, + norm_num_groups=norm_num_groups, + ctrl_max_norm_num_groups=ctrl_max_norm_num_groups, has_crossattn=has_crossattn, transformer_layers_per_block=transformer_layers_per_block[i], base_num_attention_heads=base_num_attention_heads[i], @@ -636,6 +634,8 @@ def __init__( base_channels=block_out_channels[-1], ctrl_channels=ctrl_block_out_channels[-1], temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + ctrl_max_norm_num_groups=ctrl_max_norm_num_groups, transformer_layers_per_block=transformer_layers_per_block[-1], base_num_attention_heads=base_num_attention_heads[-1], ctrl_num_attention_heads=ctrl_num_attention_heads[-1], @@ -683,6 +683,7 @@ def __init__( cross_attention_dim=rev_cross_attention_dim[-1], add_upsample=not is_final_block, upcast_attention=upcast_attention, + norm_num_groups=norm_num_groups, ) ) @@ -1184,7 +1185,8 @@ def __init__( ctrl_in_channels: int, ctrl_out_channels: int, temb_channels: int, - max_norm_num_groups: Optional[int] = 32, + norm_num_groups: int = 32, + ctrl_max_norm_num_groups: int = 32, has_crossattn=True, transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, base_num_attention_heads: Optional[int] = 1, @@ -1219,6 +1221,7 @@ def __init__( in_channels=base_in_channels, out_channels=base_out_channels, temb_channels=temb_channels, + groups=norm_num_groups, ) ) ctrl_resnets.append( @@ -1226,8 +1229,10 @@ def __init__( in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl out_channels=ctrl_out_channels, temb_channels=temb_channels, - groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups), - groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), + groups=find_largest_factor( + ctrl_in_channels + base_in_channels, max_factor=ctrl_max_norm_num_groups + ), + groups_out=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups), eps=1e-5, ) ) @@ -1242,6 +1247,7 @@ def __init__( cross_attention_dim=cross_attention_dim, use_linear_projection=True, upcast_attention=upcast_attention, + norm_num_groups=norm_num_groups, ) ) ctrl_attentions.append( @@ -1253,7 +1259,7 @@ def __init__( cross_attention_dim=cross_attention_dim, use_linear_projection=True, upcast_attention=upcast_attention, - norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), + norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups), ) ) @@ -1302,7 +1308,8 @@ def get_first_cross_attention(block): ) # base channels are concatted to ctrl channels in init ctrl_out_channels = ctrl_downblock.resnets[0].out_channels temb_channels = base_downblock.resnets[0].time_emb_proj.in_features - num_groups = ctrl_downblock.resnets[0].norm1.num_groups + num_groups = base_downblock.resnets[0].norm1.num_groups + ctrl_num_groups = ctrl_downblock.resnets[0].norm1.num_groups if hasattr(base_downblock, "attentions"): has_crossattn = True transformer_layers_per_block = len(base_downblock.attentions[0].transformer_blocks) @@ -1326,7 +1333,8 @@ def get_first_cross_attention(block): ctrl_in_channels=ctrl_in_channels, ctrl_out_channels=ctrl_out_channels, temb_channels=temb_channels, - max_norm_num_groups=num_groups, + norm_num_groups=num_groups, + ctrl_max_norm_num_groups=ctrl_num_groups, has_crossattn=has_crossattn, transformer_layers_per_block=transformer_layers_per_block, base_num_attention_heads=base_num_attention_heads, @@ -1487,7 +1495,8 @@ def __init__( base_channels: int, ctrl_channels: int, temb_channels: Optional[int] = None, - max_norm_num_groups: Optional[int] = 32, + norm_num_groups: int = 32, + ctrl_max_norm_num_groups: int = 32, transformer_layers_per_block: int = 1, base_num_attention_heads: Optional[int] = 1, ctrl_num_attention_heads: Optional[int] = 1, @@ -1504,6 +1513,7 @@ def __init__( transformer_layers_per_block=transformer_layers_per_block, in_channels=base_channels, temb_channels=temb_channels, + resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=base_num_attention_heads, use_linear_projection=True, @@ -1516,7 +1526,9 @@ def __init__( out_channels=ctrl_channels, temb_channels=temb_channels, # number or norm groups must divide both in_channels and out_channels - resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups), + resnet_groups=find_largest_factor( + gcd(ctrl_channels, ctrl_channels + base_channels), ctrl_max_norm_num_groups + ), cross_attention_dim=cross_attention_dim, num_attention_heads=ctrl_num_attention_heads, use_linear_projection=True, @@ -1547,7 +1559,8 @@ def get_first_cross_attention(midblock): ctrl_channels = ctrl_to_base.in_channels transformer_layers_per_block = len(base_midblock.attentions[0].transformer_blocks) temb_channels = base_midblock.resnets[0].time_emb_proj.in_features - num_groups = ctrl_midblock.resnets[0].norm1.num_groups + num_groups = base_midblock.resnets[0].norm1.num_groups + ctrl_num_groups = ctrl_midblock.resnets[0].norm1.num_groups base_num_attention_heads = get_first_cross_attention(base_midblock).heads ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim @@ -1558,7 +1571,8 @@ def get_first_cross_attention(midblock): base_channels=base_channels, ctrl_channels=ctrl_channels, temb_channels=temb_channels, - max_norm_num_groups=num_groups, + norm_num_groups=num_groups, + ctrl_max_norm_num_groups=ctrl_num_groups, transformer_layers_per_block=transformer_layers_per_block, base_num_attention_heads=base_num_attention_heads, ctrl_num_attention_heads=ctrl_num_attention_heads, @@ -1630,6 +1644,7 @@ def __init__( prev_output_channel: int, ctrl_skip_channels: List[int], temb_channels: int, + norm_num_groups: int = 32, resolution_idx: Optional[int] = None, has_crossattn=True, transformer_layers_per_block: int = 1, @@ -1662,6 +1677,7 @@ def __init__( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, + groups=norm_num_groups, ) ) @@ -1675,6 +1691,7 @@ def __init__( cross_attention_dim=cross_attention_dim, use_linear_projection=True, upcast_attention=upcast_attention, + norm_num_groups=norm_num_groups, ) ) @@ -1703,6 +1720,7 @@ def get_first_cross_attention(block): prev_output_channels = base_upblock.resnets[0].in_channels - out_channels ctrl_skip_channelss = [c.in_channels for c in ctrl_to_base_skip_connections] temb_channels = base_upblock.resnets[0].time_emb_proj.in_features + num_groups = base_upblock.resnets[0].norm1.num_groups resolution_idx = base_upblock.resolution_idx if hasattr(base_upblock, "attentions"): has_crossattn = True @@ -1725,6 +1743,7 @@ def get_first_cross_attention(block): prev_output_channel=prev_output_channels, ctrl_skip_channels=ctrl_skip_channelss, temb_channels=temb_channels, + norm_num_groups=num_groups, resolution_idx=resolution_idx, has_crossattn=has_crossattn, transformer_layers_per_block=transformer_layers_per_block, diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index 09c134533209..b41c07508aa1 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -44,12 +44,13 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes def dummy_input(self): batch_size = 4 num_channels = 4 - sizes = (32, 32) + sizes = (16, 16) + conditioning_image_size = (3, 32, 32) # size of additional, unprocessed image for control-conditioning noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device) - controlnet_cond = floats_tensor((batch_size, 3, 256, 256)).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) + controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device) conditioning_scale = 1 return { @@ -62,46 +63,54 @@ def dummy_input(self): @property def input_shape(self): - return (4, 32, 32) + return (4, 16, 16) @property def output_shape(self): - return (4, 32, 32) + return (4, 16, 16) def prepare_init_args_and_inputs_for_common(self): init_dict = { - "sample_size": 32, + "sample_size": 16, "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"), "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"), - "block_out_channels": (32, 64), - "cross_attention_dim": 32, + "block_out_channels": (4, 8), + "cross_attention_dim": 8, "transformer_layers_per_block": 1, - "num_attention_heads": 8, + "num_attention_heads": 2, + "norm_num_groups": 4, "upcast_attention": False, - "ctrl_block_out_channels": [4, 8], - "ctrl_num_attention_heads": 8, - "ctrl_max_norm_num_groups": 4, + "ctrl_block_out_channels": [2, 4], + "ctrl_num_attention_heads": 4, + "ctrl_max_norm_num_groups": 2, + "ctrl_conditioning_embedding_out_channels": (2, 2), } inputs_dict = self.dummy_input return init_dict, inputs_dict def get_dummy_unet(self): - """For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet""" + """For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Addon""" return UNet2DConditionModel( - block_out_channels=(32, 64), + block_out_channels=(4, 8), layers_per_block=2, - sample_size=32, + sample_size=16, in_channels=4, out_channels=4, down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), - cross_attention_dim=32, + cross_attention_dim=8, + norm_num_groups=4, use_linear_projection=True, ) + def get_dummy_controlnet_from_unet(self, unet, **kwargs): + """For some tests we also need the underlying ControlNetXS-Addon. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Addon""" + # size_ratio and conditioning_embedding_out_channels chosen to keep model small + return ControlNetXSAddon.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs) + def test_from_unet(self): unet = self.get_dummy_unet() - controlnet = ControlNetXSAddon.from_unet(unet, size_ratio=1) + controlnet = self.get_dummy_controlnet_from_unet(unet) model = UNetControlNetXSModel.from_unet(unet, controlnet) model_state_dict = model.state_dict() @@ -298,7 +307,7 @@ def _set_gradient_checkpointing_new(self, module, value=False): def test_forward_no_control(self): unet = self.get_dummy_unet() - controlnet = ControlNetXSAddon.from_unet(unet, size_ratio=1) + controlnet = self.get_dummy_controlnet_from_unet(unet) model = UNetControlNetXSModel.from_unet(unet, controlnet) @@ -318,9 +327,9 @@ def test_forward_no_control(self): def test_time_embedding_mixing(self): unet = self.get_dummy_unet() - controlnet = ControlNetXSAddon.from_unet(unet, size_ratio=1) - controlnet_mix_time = ControlNetXSAddon.from_unet( - unet, size_ratio=1, time_embedding_mix=0.5, learn_time_embedding=True + controlnet = self.get_dummy_controlnet_from_unet(unet) + controlnet_mix_time = self.get_dummy_controlnet_from_unet( + unet, time_embedding_mix=0.5, learn_time_embedding=True ) model = UNetControlNetXSModel.from_unet(unet, controlnet) diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index e91f9d8d313a..a619f7d07a97 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -134,24 +134,24 @@ class ControlNetXSPipelineFastTests( def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) unet = UNet2DConditionModel( - block_out_channels=(32, 64), + block_out_channels=(4, 8), layers_per_block=2, - sample_size=32, + sample_size=16, in_channels=4, out_channels=4, down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), - cross_attention_dim=32, + cross_attention_dim=8, + norm_num_groups=4, time_cond_proj_dim=time_cond_proj_dim, use_linear_projection=True, ) torch.manual_seed(0) controlnet = ControlNetXSAddon.from_unet( unet=unet, - size_ratio=0.5, - num_attention_heads=2, + size_ratio=1, learn_time_embedding=True, - conditioning_embedding_out_channels=(16, 32), + conditioning_embedding_out_channels=(2, 2), ) torch.manual_seed(0) scheduler = DDIMScheduler( @@ -175,7 +175,7 @@ def get_dummy_components(self, time_cond_proj_dim=None): text_encoder_config = CLIPTextConfig( bos_token_id=0, eos_token_id=2, - hidden_size=32, + hidden_size=8, intermediate_size=37, layer_norm_eps=1e-05, num_attention_heads=4, @@ -206,7 +206,7 @@ def get_dummy_inputs(self, device, seed=0): controlnet_embedder_scale_factor = 2 image = randn_tensor( - (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + (1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor), generator=generator, device=torch.device(device), ) @@ -235,7 +235,7 @@ def test_inference_batch_single_identical(self): def test_controlnet_lcm(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator - components = self.get_dummy_components(time_cond_proj_dim=256) + components = self.get_dummy_components(time_cond_proj_dim=8) sd_pipe = StableDiffusionControlNetXSPipeline(**components) sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config) sd_pipe = sd_pipe.to(torch_device) @@ -247,8 +247,8 @@ def test_controlnet_lcm(self): image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.491, 0.411, 0.292, 0.631, 0.506, 0.439, 0.664, 0.67, 0.447]) + assert image.shape == (1, 16, 16, 3) + expected_slice = np.array([0.745, 0.753, 0.767, 0.543, 0.523, 0.502, 0.314, 0.521, 0.478]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index d3c846041618..4d3d5071f13d 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -75,28 +75,29 @@ class StableDiffusionXLControlNetXSPipelineFastTests( def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( - block_out_channels=(32, 64), + block_out_channels=(4, 8), layers_per_block=2, - sample_size=32, + sample_size=16, in_channels=4, out_channels=4, down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), use_linear_projection=True, + norm_num_groups=4, # SD2-specific config below attention_head_dim=(2, 4), addition_embed_type="text_time", addition_time_embed_dim=8, transformer_layers_per_block=(1, 2), - projection_class_embeddings_input_dim=80, # 6 * 8 + 32 - cross_attention_dim=64, + projection_class_embeddings_input_dim=56, # 6 * 8 (addition_time_embed_dim) + 8 (cross_attention_dim) + cross_attention_dim=8, ) torch.manual_seed(0) controlnet = ControlNetXSAddon.from_unet( unet=unet, size_ratio=0.5, learn_time_embedding=True, - conditioning_embedding_out_channels=(16, 32), + conditioning_embedding_out_channels=(2, 2), ) torch.manual_seed(0) scheduler = EulerDiscreteScheduler( @@ -108,18 +109,19 @@ def get_dummy_components(self): ) torch.manual_seed(0) vae = AutoencoderKL( - block_out_channels=[32, 64], + block_out_channels=[4, 8], in_channels=3, out_channels=3, down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], latent_channels=4, + norm_num_groups=2, ) torch.manual_seed(0) text_encoder_config = CLIPTextConfig( bos_token_id=0, eos_token_id=2, - hidden_size=32, + hidden_size=4, intermediate_size=37, layer_norm_eps=1e-05, num_attention_heads=4, @@ -128,7 +130,7 @@ def get_dummy_components(self): vocab_size=1000, # SD2-specific config below hidden_act="gelu", - projection_dim=32, + projection_dim=8, ) text_encoder = CLIPTextModel(text_encoder_config) tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") @@ -158,7 +160,7 @@ def get_dummy_inputs(self, device, seed=0): controlnet_embedder_scale_factor = 2 image = randn_tensor( - (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + (1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor), generator=generator, device=torch.device(device), ) From f334238ceb02e6fd428229080ee0a41e223d57fa Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Tue, 9 Apr 2024 13:55:11 +0200 Subject: [PATCH 73/75] Renamed cnxs-`Addon` to cnxs-`Adapter` - `ControlNetXSAddon` -> `ControlNetXSAdapter` - `ControlNetXSAddonDownBlockComponents` -> `DownBlockControlNetXSAdapter`, and similarly for mid/up - `get_mid_block_addon` -> `get_mid_block_adapter`, and similarly for mid/up --- src/diffusers/__init__.py | 4 +- src/diffusers/models/__init__.py | 4 +- src/diffusers/models/controlnet_xs.py | 92 ++++++++++--------- .../controlnet_xs/pipeline_controlnet_xs.py | 14 +-- .../pipeline_controlnet_xs_sd_xl.py | 14 +-- src/diffusers/utils/dummy_pt_objects.py | 2 +- .../unets/test_models_unet_controlnetxs.py | 8 +- .../controlnet_xs/test_controlnetxs.py | 10 +- .../controlnet_xs/test_controlnetxs_sdxl.py | 8 +- 9 files changed, 80 insertions(+), 76 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8a1d4161bce0..5d6761663938 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -80,7 +80,7 @@ "AutoencoderTiny", "ConsistencyDecoderVAE", "ControlNetModel", - "ControlNetXSAddon", + "ControlNetXSAdapter", "I2VGenXLUNet", "Kandinsky3UNet", "ModelMixin", @@ -478,7 +478,7 @@ AutoencoderTiny, ConsistencyDecoderVAE, ControlNetModel, - ControlNetXSAddon, + ControlNetXSAdapter, I2VGenXLUNet, Kandinsky3UNet, ModelMixin, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e5dc43d5792d..78b0efff921d 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -32,7 +32,7 @@ _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["controlnet"] = ["ControlNetModel"] - _import_structure["controlnet_xs"] = ["ControlNetXSAddon", "UNetControlNetXSModel"] + _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] @@ -69,7 +69,7 @@ ConsistencyDecoderVAE, ) from .controlnet import ControlNetModel - from .controlnet_xs import ControlNetXSAddon, UNetControlNetXSModel + from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 165c8e8273e7..7bb1c4e7977c 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -55,7 +55,7 @@ class ControlNetXSOutput(BaseOutput): sample: FloatTensor = None -class ControlNetXSAddonDownBlockComponents(nn.Module): +class DownBlockControlNetXSAdapter(nn.Module): """Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnDownBlock2D`""" @@ -75,7 +75,7 @@ def __init__( self.downsamplers = downsampler -class ControlNetXSAddonMidBlockComponents(nn.Module): +class MidBlockControlNetXSAdapter(nn.Module): """Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnMidBlock2D`""" @@ -86,7 +86,7 @@ def __init__(self, midblock: UNetMidBlock2DCrossAttn, base_to_ctrl: nn.ModuleLis self.ctrl_to_base = ctrl_to_base -class ControlNetXSAddonUpBlockComponents(nn.Module): +class UpBlockControlNetXSAdapter(nn.Module): """Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnUpBlock2D`""" def __init__(self, ctrl_to_base: nn.ModuleList): @@ -94,7 +94,7 @@ def __init__(self, ctrl_to_base: nn.ModuleList): self.ctrl_to_base = ctrl_to_base -def get_down_block_addon( +def get_down_block_adapter( base_in_channels: int, base_out_channels: int, ctrl_in_channels: int, @@ -170,7 +170,7 @@ def get_down_block_addon( else: downsamplers = None - down_block_components = ControlNetXSAddonDownBlockComponents( + down_block_components = DownBlockControlNetXSAdapter( resnets=nn.ModuleList(resnets), base_to_ctrl=nn.ModuleList(base_to_ctrl), ctrl_to_base=nn.ModuleList(ctrl_to_base), @@ -184,7 +184,7 @@ def get_down_block_addon( return down_block_components -def get_mid_block_addon( +def get_mid_block_adapter( base_channels: int, ctrl_channels: int, temb_channels: Optional[int] = None, @@ -215,10 +215,10 @@ def get_mid_block_addon( # Addition requires change in number of channels ctrl_to_base = make_zero_conv(ctrl_channels, base_channels) - return ControlNetXSAddonMidBlockComponents(base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base) + return MidBlockControlNetXSAdapter(base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base) -def get_up_block_addon( +def get_up_block_adapter( out_channels: int, prev_output_channel: int, ctrl_skip_channels: List[int], @@ -229,18 +229,18 @@ def get_up_block_addon( resnet_in_channels = prev_output_channel if i == 0 else out_channels ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels)) - return ControlNetXSAddonUpBlockComponents(ctrl_to_base=nn.ModuleList(ctrl_to_base)) + return UpBlockControlNetXSAdapter(ctrl_to_base=nn.ModuleList(ctrl_to_base)) -class ControlNetXSAddon(ModelMixin, ConfigMixin): +class ControlNetXSAdapter(ModelMixin, ConfigMixin): r""" - A `ControlNetXSAddon` model. To use it, pass it into a `ControlNetXSModel` (together with a `UNet2DConditionModel` - base model). + A `ControlNetXSAdapter` model. To use it, pass it into a `UNetControlNetXSModel` (together with a + `UNet2DConditionModel` base model). This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). - Like `ControlNetXSModel`, `ControlNetXSAddon` is compatible with StableDiffusion and StableDiffusion-XL. It's + Like `UNetControlNetXSModel`, `ControlNetXSAdapter` is compatible with StableDiffusion and StableDiffusion-XL. It's default parameters are compatible with StableDiffusion. Parameters: @@ -251,11 +251,12 @@ class ControlNetXSAddon(ModelMixin, ConfigMixin): conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): The tuple of output channels for each block in the `controlnet_cond_embedding` layer. time_embedding_mix (`float`, defaults to 1.0): - If 0, then only the control addon's time embedding is used. If 1, then only the base unet's time embedding - is used. Otherwise, both are combined. + If 0, then only the control adapters's time embedding is used. If 1, then only the base unet's time + embedding is used. Otherwise, both are combined. learn_time_embedding (`bool`, defaults to `False`): - Whether a time embedding should be learned. If yes, `ControlNetXSModel` will combine the time embeddings of - the base model and the addon. If no, `ControlNetXSModel` will use the base model's time embedding. + Whether a time embedding should be learned. If yes, `UNetControlNetXSModel` will combine the time + embeddings of the base model and the control adapter. If no, `UNetControlNetXSModel` will use the base + model's time embedding. num_attention_heads (`list[int]`, defaults to `[4]`): The number of attention heads. block_out_channels (`list[int]`, defaults to `[4, 8, 16, 16]`): @@ -319,7 +320,7 @@ def __init__( transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if not isinstance(cross_attention_dim, (list, tuple)): cross_attention_dim = [cross_attention_dim] * len(down_block_types) - # see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why `ControlNetXSAddon` takes `num_attention_heads` instead of `attention_head_dim` + # see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why `ControlNetXSAdapter` takes `num_attention_heads` instead of `attention_head_dim` if not isinstance(num_attention_heads, (list, tuple)): num_attention_heads = [num_attention_heads] * len(down_block_types) @@ -360,7 +361,7 @@ def __init__( is_final_block = i == len(down_block_types) - 1 self.down_blocks.append( - get_down_block_addon( + get_down_block_adapter( base_in_channels=base_in_channels, base_out_channels=base_out_channels, ctrl_in_channels=ctrl_in_channels, @@ -377,7 +378,7 @@ def __init__( ) # mid - self.mid_block = get_mid_block_addon( + self.mid_block = get_mid_block_adapter( base_channels=base_block_out_channels[-1], ctrl_channels=block_out_channels[-1], temb_channels=time_embedding_dim, @@ -405,7 +406,7 @@ def __init__( ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)] self.up_connections.append( - get_up_block_addon( + get_up_block_adapter( out_channels=base_out_channels, prev_output_channel=prev_base_output_channel, ctrl_skip_channels=ctrl_skip_channels_, @@ -426,11 +427,11 @@ def from_unet( conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), ): r""" - Instantiate a [`ControlNetXSAddon`] from a [`UNet2DConditionModel`]. + Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`]. Parameters: unet (`UNet2DConditionModel`): - The UNet model we want to control. The dimensions of the ControlNetXSAddon will be adapted to it. + The UNet model we want to control. The dimensions of the ControlNetXSAdapter will be adapted to it. size_ratio (float, *optional*, defaults to `None`): When given, block_out_channels is set to a fraction of the base model's block_out_channels. Either this or `block_out_channels` must be given. @@ -440,9 +441,9 @@ def from_unet( The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. learn_time_embedding (`bool`, defaults to `False`): - Whether the `ControlNetXSAddon` should learn a time embedding. + Whether the `ControlNetXSAdapter` should learn a time embedding. time_embedding_mix (`float`, defaults to 1.0): - If 0, then only the control addon's time embedding is used. If 1, then only the base unet's time + If 0, then only the control adapter's time embedding is used. If 1, then only the base unet's time embedding is used. Otherwise, both are combined. conditioning_channels (`int`, defaults to 3): Number of channels of conditioning input (e.g. an image) @@ -483,20 +484,20 @@ def from_unet( max_norm_num_groups=unet.config.norm_num_groups, ) - # ensure that the ControlNetXSAddon is the same dtype as the UNet2DConditionModel + # ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel model.to(unet.dtype) return model def forward(self, *args, **kwargs): raise ValueError( - "A ControlNetXSAddonModel cannot be run by itself. Pass it into a ControlNetXSModel model instead." + "A ControlNetXSAdapter cannot be run by itself. Use it together with a UNet2DConditionModel to instantiate a UNetControlNetXSModel." ) class UNetControlNetXSModel(ModelMixin, ConfigMixin): r""" - A UNet fused with a ControlNet-XS addon model + A UNet fused with a ControlNet-XS adapter model This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). @@ -505,7 +506,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): compatible with StableDiffusion. It's parameters are either passed to the underlying `UNet2DConditionModel` or used exactly like in - `ControlNetXSAddon` . See their documentation for details. + `ControlNetXSAdapter` . See their documentation for details. """ _supports_gradient_checkpointing = True @@ -547,9 +548,7 @@ def __init__( if time_embedding_mix < 0 or time_embedding_mix > 1: raise ValueError("`time_embedding_mix` needs to be between 0 and 1.") if time_embedding_mix < 1 and not ctrl_learn_time_embedding: - raise ValueError( - "To use `time_embedding_mix` < 1, initialize `ctrl_addon` with `learn_time_embedding = True`" - ) + raise ValueError("To use `time_embedding_mix` < 1, `ctrl_learn_time_embedding` must be `True`") if addition_embed_type is not None and addition_embed_type != "text_time": raise ValueError( @@ -698,33 +697,36 @@ def __init__( def from_unet( cls, unet: UNet2DConditionModel, - controlnet: Optional[ControlNetXSAddon] = None, + controlnet: Optional[ControlNetXSAdapter] = None, size_ratio: Optional[float] = None, ctrl_block_out_channels: Optional[List[float]] = None, time_embedding_mix: Optional[float] = None, ctrl_optional_kwargs: Optional[Dict] = None, ): r""" - Instantiate a [`UNetControlNetXSModel`] from a [`UNet2DConditionModel`] and an optional [`ControlNetXSAddon`] . + Instantiate a [`UNetControlNetXSModel`] from a [`UNet2DConditionModel`] and an optional [`ControlNetXSAdapter`] + . Parameters: unet (`UNet2DConditionModel`): The UNet model we want to control. - controlnet (`ControlNetXSAddon`): - The ConntrolNet-XS addon with which the UNet will be fused. If none is given, a new ConntrolNet-XS - addon will be created. + controlnet (`ControlNetXSAdapter`): + The ConntrolNet-XS adapter with which the UNet will be fused. If none is given, a new ConntrolNet-XS + adapter will be created. size_ratio (float, *optional*, defaults to `None`): - Used to contruct the controlnet if none is given. See [`ControlNetXSAddon.from_unet`] for details. + Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details. ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`): - Used to contruct the controlnet if none is given. See [`ControlNetXSAddon.from_unet`] for details, + Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details, where this parameter is called `block_out_channels`. time_embedding_mix (`float`, *optional*, defaults to None): - Used to contruct the controlnet if none is given. See [`ControlNetXSAddon.from_unet`] for details. + Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details. ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`): Passed to the `init` of the new controlent if no controlent was given. """ if controlnet is None: - controlnet = ControlNetXSAddon.from_unet(unet, size_ratio, ctrl_block_out_channels, **ctrl_optional_kwargs) + controlnet = ControlNetXSAdapter.from_unet( + unet, size_ratio, ctrl_block_out_channels, **ctrl_optional_kwargs + ) else: if any( o is not None for o in (size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs) @@ -816,7 +818,7 @@ def freeze_unet_params(self) -> None: for param in self.parameters(): param.requires_grad = True - # Unfreeze ControlNetXSAddon + # Unfreeze ControlNetXSAdapter base_parts = [ "base_time_proj", "base_time_embedding", @@ -1296,7 +1298,7 @@ def __init__( self.gradient_checkpointing = False @classmethod - def from_modules(cls, base_downblock: CrossAttnDownBlock2D, ctrl_downblock: ControlNetXSAddonDownBlockComponents): + def from_modules(cls, base_downblock: CrossAttnDownBlock2D, ctrl_downblock: DownBlockControlNetXSAdapter): # get params def get_first_cross_attention(block): return block.attentions[0].transformer_blocks[0].attn2 @@ -1545,7 +1547,7 @@ def __init__( def from_modules( cls, base_midblock: UNetMidBlock2DCrossAttn, - ctrl_midblock: ControlNetXSAddonMidBlockComponents, + ctrl_midblock: MidBlockControlNetXSAdapter, ): base_to_ctrl = ctrl_midblock.base_to_ctrl ctrl_to_base = ctrl_midblock.ctrl_to_base @@ -1708,7 +1710,7 @@ def __init__( self.resolution_idx = resolution_idx @classmethod - def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_upblock: ControlNetXSAddonUpBlockComponents): + def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_upblock: UpBlockControlNetXSAdapter): ctrl_to_base_skip_connections = ctrl_upblock.ctrl_to_base # get params diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 8ab5b74686b8..2f450b9c2cea 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -23,7 +23,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetXSAddon, UNet2DConditionModel, UNetControlNetXSModel +from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -47,7 +47,7 @@ Examples: ```py >>> # !pip install opencv-python transformers accelerate - >>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSAddon + >>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSAdapter >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch @@ -66,7 +66,7 @@ >>> # initialize the models and pipeline >>> controlnet_conditioning_scale = 0.5 - >>> controlnet = ControlNetXSAddon.from_pretrained( + >>> controlnet = ControlNetXSAdapter.from_pretrained( ... "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 ... ) >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( @@ -110,8 +110,10 @@ class StableDiffusionControlNetXSPipeline( Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. - controlnet ([`ControlNetXSModel`]): - A model containing a base UNet and a control addon. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents. + controlnet ([`ControlNetXSAdapter`]): + A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. @@ -134,7 +136,7 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: Union[UNet2DConditionModel, UNetControlNetXSModel], - controlnet: ControlNetXSAddon, + controlnet: ControlNetXSAdapter, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index 697b5a17364d..ff270d20d11e 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -30,7 +30,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetXSAddon, UNet2DConditionModel, UNetControlNetXSModel +from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -63,7 +63,7 @@ Examples: ```py >>> # !pip install opencv-python transformers accelerate - >>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSAddon, AutoencoderKL + >>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSAdapter, AutoencoderKL >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch @@ -82,7 +82,7 @@ >>> # initialize the models and pipeline >>> controlnet_conditioning_scale = 0.5 >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) - >>> controlnet = ControlNetXSAddon.from_pretrained( + >>> controlnet = ControlNetXSAdapter.from_pretrained( ... "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16 ... ) >>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( @@ -135,9 +135,9 @@ class StableDiffusionXLControlNetXSPipeline( tokenizer_2 ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): - A `UNet2DConditionModel` to denoise the encoded image latents. - controlnet ([`ControlNetXSModel`]: - Provides additional conditioning to the `unet` during the denoising process. + A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents. + controlnet ([`ControlNetXSAdapter`]): + A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. @@ -168,7 +168,7 @@ def __init__( tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, unet: Union[UNet2DConditionModel, UNetControlNetXSModel], - controlnet: ControlNetXSAddon, + controlnet: ControlNetXSAdapter, scheduler: KarrasDiffusionSchedulers, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 3db200e65336..b04006cb5ee6 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -92,7 +92,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class ControlNetXSAddon(metaclass=DummyObject): +class ControlNetXSAdapter(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index b41c07508aa1..8c9b43a20ad6 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -20,7 +20,7 @@ import torch from torch import nn -from diffusers import ControlNetXSAddon, UNet2DConditionModel, UNetControlNetXSModel +from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel from diffusers.utils import logging from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -89,7 +89,7 @@ def prepare_init_args_and_inputs_for_common(self): return init_dict, inputs_dict def get_dummy_unet(self): - """For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Addon""" + """For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter""" return UNet2DConditionModel( block_out_channels=(4, 8), layers_per_block=2, @@ -104,9 +104,9 @@ def get_dummy_unet(self): ) def get_dummy_controlnet_from_unet(self, unet, **kwargs): - """For some tests we also need the underlying ControlNetXS-Addon. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Addon""" + """For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter""" # size_ratio and conditioning_embedding_out_channels chosen to keep model small - return ControlNetXSAddon.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs) + return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs) def test_from_unet(self): unet = self.get_dummy_unet() diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index a619f7d07a97..5ac78129ef34 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -26,7 +26,7 @@ AutoencoderKL, AutoencoderTiny, ConsistencyDecoderVAE, - ControlNetXSAddon, + ControlNetXSAdapter, DDIMScheduler, LCMScheduler, StableDiffusionControlNetXSPipeline, @@ -75,7 +75,7 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout): try: _ = in_queue.get(timeout=timeout) - controlnet = ControlNetXSAddon.from_pretrained( + controlnet = ControlNetXSAdapter.from_pretrained( "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 ) pipe = StableDiffusionControlNetXSPipeline.from_pretrained( @@ -147,7 +147,7 @@ def get_dummy_components(self, time_cond_proj_dim=None): use_linear_projection=True, ) torch.manual_seed(0) - controlnet = ControlNetXSAddon.from_unet( + controlnet = ControlNetXSAdapter.from_unet( unet=unet, size_ratio=1, learn_time_embedding=True, @@ -309,7 +309,7 @@ def tearDown(self): torch.cuda.empty_cache() def test_canny(self): - controlnet = ControlNetXSAddon.from_pretrained( + controlnet = ControlNetXSAdapter.from_pretrained( "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 ) pipe = StableDiffusionControlNetXSPipeline.from_pretrained( @@ -335,7 +335,7 @@ def test_canny(self): assert np.allclose(original_image, expected_image, atol=1e-04) def test_depth(self): - controlnet = ControlNetXSAddon.from_pretrained( + controlnet = ControlNetXSAdapter.from_pretrained( "UmerHA/Testing-ConrolNetXS-SD2.1-depth", torch_dtype=torch.float16 ) pipe = StableDiffusionControlNetXSPipeline.from_pretrained( diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index 4d3d5071f13d..ee0d15ec3472 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -25,7 +25,7 @@ AutoencoderKL, AutoencoderTiny, ConsistencyDecoderVAE, - ControlNetXSAddon, + ControlNetXSAdapter, EulerDiscreteScheduler, StableDiffusionXLControlNetXSPipeline, UNet2DConditionModel, @@ -93,7 +93,7 @@ def get_dummy_components(self): cross_attention_dim=8, ) torch.manual_seed(0) - controlnet = ControlNetXSAddon.from_unet( + controlnet = ControlNetXSAdapter.from_unet( unet=unet, size_ratio=0.5, learn_time_embedding=True, @@ -377,7 +377,7 @@ def tearDown(self): torch.cuda.empty_cache() def test_canny(self): - controlnet = ControlNetXSAddon.from_pretrained( + controlnet = ControlNetXSAdapter.from_pretrained( "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16 ) pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( @@ -401,7 +401,7 @@ def test_canny(self): assert np.allclose(original_image, expected_image, atol=1e-04) def test_depth(self): - controlnet = ControlNetXSAddon.from_pretrained( + controlnet = ControlNetXSAdapter.from_pretrained( "UmerHA/Testing-ConrolNetXS-SDXL-depth", torch_dtype=torch.float16 ) pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( From 7782b58a80e3a6da14b9cf8844afd1b88faedd13 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 11 Apr 2024 22:01:58 +0200 Subject: [PATCH 74/75] Fixed save_pretrained/from_pretrained bug --- src/diffusers/models/controlnet_xs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 7bb1c4e7977c..4cc1cb68315b 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -677,9 +677,9 @@ def __init__( temb_channels=time_embed_dim, resolution_idx=i, has_crossattn=has_crossattn, - transformer_layers_per_block=rev_transformer_layers_per_block[-1], - num_attention_heads=rev_num_attention_heads[-1], - cross_attention_dim=rev_cross_attention_dim[-1], + transformer_layers_per_block=rev_transformer_layers_per_block[i], + num_attention_heads=rev_num_attention_heads[i], + cross_attention_dim=rev_cross_attention_dim[i], add_upsample=not is_final_block, upcast_attention=upcast_attention, norm_num_groups=norm_num_groups, From b5815cc4911dc638df484504f92a336af3a791e6 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Fri, 12 Apr 2024 18:25:41 +0200 Subject: [PATCH 75/75] Removed redundant code --- src/diffusers/models/controlnet_xs.py | 33 ++------------------------- 1 file changed, 2 insertions(+), 31 deletions(-) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 4cc1cb68315b..4bbe1dd4dc25 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -1807,17 +1807,6 @@ def forward( and getattr(self, "b2", None) ) - # In ControlNet-XS, the last resnet/attention and the upsampler are treated together as one group. - # So we separate them to pass information from ctrl to base correctly. - if self.upsamplers is None: - resnets_without_upsampler = self.resnets - attn_without_upsampler = self.attentions - else: - resnets_without_upsampler = self.resnets[:-1] - attn_without_upsampler = self.attentions[:-1] - resnet_with_upsampler = self.resnets[-1] - attn_with_upsampler = self.attentions[-1] - def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: @@ -1843,8 +1832,8 @@ def maybe_apply_freeu_to_subblock(hidden_states, res_h_base): return hidden_states, res_h_base for resnet, attn, c2b, res_h_base, res_h_ctrl in zip( - resnets_without_upsampler, - attn_without_upsampler, + self.resnets, + self.attentions, self.ctrl_to_base, reversed(res_hidden_states_tuple_base), reversed(res_hidden_states_tuple_ctrl), @@ -1877,24 +1866,6 @@ def maybe_apply_freeu_to_subblock(hidden_states, res_h_base): )[0] if self.upsamplers is not None: - c2b = self.ctrl_to_base[-1] - res_h_base = res_hidden_states_tuple_base[0] - res_h_ctrl = res_hidden_states_tuple_ctrl[0] - if apply_control: - hidden_states += c2b(res_h_ctrl) * conditioning_scale - hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base) - hidden_states = torch.cat([hidden_states, res_h_base], dim=1) - - hidden_states = resnet_with_upsampler(hidden_states, temb) - if attn_with_upsampler is not None: - hidden_states = attn_with_upsampler( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] hidden_states = self.upsamplers(hidden_states, upsample_size) return hidden_states