diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 13687dc5419e..bf23d363a863 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -112,6 +112,8 @@ title: "Stochastic Karras VE" - local: api/pipelines/dance_diffusion title: "Dance Diffusion" + - local: api/pipelines/versatile_diffusion + title: "Versatile Diffusion" - local: api/pipelines/vq_diffusion title: "VQ Diffusion" - local: api/pipelines/repaint diff --git a/docs/source/api/pipelines/overview.mdx b/docs/source/api/pipelines/overview.mdx index ff83bad05525..c43f09d66dde 100644 --- a/docs/source/api/pipelines/overview.mdx +++ b/docs/source/api/pipelines/overview.mdx @@ -60,6 +60,9 @@ available a colab notebook to directly try them out. | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) | [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) | [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation | | [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | diff --git a/docs/source/api/pipelines/versatile_diffusion.mdx b/docs/source/api/pipelines/versatile_diffusion.mdx new file mode 100644 index 000000000000..f557c5b0aac8 --- /dev/null +++ b/docs/source/api/pipelines/versatile_diffusion.mdx @@ -0,0 +1,73 @@ + + +# VersatileDiffusion + +VersatileDiffusion was proposed in [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) by Xingqian Xu, Zhangyang Wang, Eric Zhang, Kai Wang, Humphrey Shi . + +The abstract of the paper is the following: + +*The recent advances in diffusion models have set an impressive milestone in many generation tasks. Trending works such as DALL-E2, Imagen, and Stable Diffusion have attracted great interest in academia and industry. Despite the rapid landscape changes, recent new approaches focus on extensions and performance rather than capacity, thus requiring separate models for separate tasks. In this work, we expand the existing single-flow diffusion pipeline into a multi-flow network, dubbed Versatile Diffusion (VD), that handles text-to-image, image-to-text, image-variation, and text-variation in one unified model. Moreover, we generalize VD to a unified multi-flow multimodal diffusion framework with grouped layers, swappable streams, and other propositions that can process modalities beyond images and text. Through our experiments, we demonstrate that VD and its underlying framework have the following merits: a) VD handles all subtasks with competitive quality; b) VD initiates novel extensions and applications such as disentanglement of style and semantic, image-text dual-guided generation, etc.; c) Through these experiments and applications, VD provides more semantic insights of the generated outputs.* + +## Tips + +- VersatileDiffusion is conceptually very similar as [Stable Diffusion](./api/pipelines/stable_diffusion), but instead of providing just a image data stream conditioned on text, VersatileDiffusion provides both a image and text data stream and can be conditioned on both text and image. + +### *Run VersatileDiffusion* + +You can both load the memory intensive "all-in-one" [`VersatileDiffusionPipeline`] that can run all tasks +with the same class as shown in [`VersatileDiffusionPipeline.text_to_image`], [`VersatileDiffusionPipeline.image_variation`], and [`VersatileDiffusionPipeline.dual_guided`] + +**or** + +You can run the individual pipelines which are much more memory efficient: + +- *Text-to-Image*: [`VersatileDiffusionTextToImagePipeline.__call__`] +- *Image Variation*: [`VersatileDiffusionImageVariationPipeline.__call__`] +- *Dual Text and Image Guided Generation*: [`VersatileDiffusionDualGuidedPipeline.__call__`] + +### *How to load and use different schedulers.* + +The versatile diffusion pipelines uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the alt diffusion pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc. +To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following: + +```python +>>> from diffusers import VersatileDiffusionPipeline, EulerDiscreteScheduler + +>>> pipeline = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion") +>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) + +>>> # or +>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("shi-labs/versatile-diffusion", subfolder="scheduler") +>>> pipeline = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", scheduler=euler_scheduler) +``` + +## VersatileDiffusionPipeline +[[autodoc]] VersatileDiffusionPipeline + +## VersatileDiffusionTextToImagePipeline +[[autodoc]] VersatileDiffusionTextToImagePipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing + +## VersatileDiffusionImageVariationPipeline +[[autodoc]] VersatileDiffusionImageVariationPipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing + +## VersatileDiffusionDualGuidedPipeline +[[autodoc]] VersatileDiffusionDualGuidedPipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 1c5ecc5fe3cb..09cc59fda99c 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -50,6 +50,9 @@ available a colab notebook to directly try them out. | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) | [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) | [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation | | [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. diff --git a/scripts/convert_versatile_diffusion_to_diffusers.py b/scripts/convert_versatile_diffusion_to_diffusers.py new file mode 100644 index 000000000000..86fb0e7b4c97 --- /dev/null +++ b/scripts/convert_versatile_diffusion_to_diffusers.py @@ -0,0 +1,791 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Conversion script for the Versatile Stable Diffusion checkpoints. """ + +import argparse +from argparse import Namespace + +import torch + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UNet2DConditionModel, + VersatileDiffusionPipeline, +) +from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel +from transformers import ( + CLIPFeatureExtractor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + + +SCHEDULER_CONFIG = Namespace( + **{ + "beta_linear_start": 0.00085, + "beta_linear_end": 0.012, + "timesteps": 1000, + "scale_factor": 0.18215, + } +) + +IMAGE_UNET_CONFIG = Namespace( + **{ + "input_channels": 4, + "model_channels": 320, + "output_channels": 4, + "num_noattn_blocks": [2, 2, 2, 2], + "channel_mult": [1, 2, 4, 4], + "with_attn": [True, True, True, False], + "num_heads": 8, + "context_dim": 768, + "use_checkpoint": True, + } +) + +TEXT_UNET_CONFIG = Namespace( + **{ + "input_channels": 768, + "model_channels": 320, + "output_channels": 768, + "num_noattn_blocks": [2, 2, 2, 2], + "channel_mult": [1, 2, 4, 4], + "second_dim": [4, 4, 4, 4], + "with_attn": [True, True, True, False], + "num_heads": 8, + "context_dim": 768, + "use_checkpoint": True, + } +) + +AUTOENCODER_CONFIG = Namespace( + **{ + "double_z": True, + "z_channels": 4, + "resolution": 256, + "in_channels": 3, + "out_ch": 3, + "ch": 128, + "ch_mult": [1, 2, 4, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + } +) + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif path["old"] in old_checkpoint: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_image_unet_diffusers_config(unet_params): + """ + Creates a config for the diffusers based on the config of the VD model. + """ + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if unet_params.with_attn[i] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if unet_params.with_attn[-i - 1] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks): + raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.") + + config = dict( + sample_size=None, + in_channels=unet_params.input_channels, + out_channels=unet_params.output_channels, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=unet_params.num_noattn_blocks[0], + cross_attention_dim=unet_params.context_dim, + attention_head_dim=unet_params.num_heads, + ) + + return config + + +def create_text_unet_diffusers_config(unet_params): + """ + Creates a config for the diffusers based on the config of the VD model. + """ + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlockFlat" if unet_params.with_attn[i] else "DownBlockFlat" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlockFlat" if unet_params.with_attn[-i - 1] else "UpBlockFlat" + up_block_types.append(block_type) + resolution //= 2 + + if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks): + raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.") + + config = dict( + sample_size=None, + in_channels=(unet_params.input_channels, 1, 1), + out_channels=(unet_params.output_channels, 1, 1), + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=unet_params.num_noattn_blocks[0], + cross_attention_dim=unet_params.context_dim, + attention_head_dim=unet_params.num_heads, + ) + + return config + + +def create_vae_diffusers_config(vae_params): + """ + Creates a config for the diffusers based on the config of the VD model. + """ + + block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = dict( + sample_size=vae_params.resolution, + in_channels=vae_params.in_channels, + out_channels=vae_params.out_ch, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + latent_channels=vae_params.z_channels, + layers_per_block=vae_params.num_res_blocks, + ) + return config + + +def create_diffusers_scheduler(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config.model.params.timesteps, + beta_start=original_config.model.params.linear_start, + beta_end=original_config.model.params.linear_end, + beta_schedule="scaled_linear", + ) + return schedular + + +def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100: + print("Checkpoint has both EMA and non-EMA weights.") + if extract_ema: + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["model.diffusion_model.time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["model.diffusion_model.time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["model.diffusion_model.time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["model.diffusion_model.time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + elif f"input_blocks.{i}.0.weight" in unet_state_dict: + # text_unet uses linear layers in place of downsamplers + shape = unet_state_dict[f"input_blocks.{i}.0.weight"].shape + if shape[0] != shape[1]: + continue + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if ["conv.weight", "conv.bias"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + elif f"output_blocks.{i}.1.weight" in unet_state_dict: + # text_unet uses linear layers in place of upsamplers + shape = unet_state_dict[f"output_blocks.{i}.1.weight"].shape + if shape[0] != shape[1]: + continue + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop( + f"output_blocks.{i}.1.weight" + ) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop( + f"output_blocks.{i}.1.bias" + ) + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + elif f"output_blocks.{i}.2.weight" in unet_state_dict: + # text_unet uses linear layers in place of upsamplers + shape = unet_state_dict[f"output_blocks.{i}.2.weight"].shape + if shape[0] != shape[1]: + continue + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop( + f"output_blocks.{i}.2.weight" + ) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop( + f"output_blocks.{i}.2.bias" + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + +def convert_vd_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + for key in keys: + vae_state_dict[key] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--unet_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--vae_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--optimus_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--scheduler_type", + default="pndm", + type=str, + help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']", + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + args = parser.parse_args() + + scheduler_config = SCHEDULER_CONFIG + + num_train_timesteps = scheduler_config.timesteps + beta_start = scheduler_config.beta_linear_start + beta_end = scheduler_config.beta_linear_end + if args.scheduler_type == "pndm": + scheduler = PNDMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + skip_prk_steps=True, + steps_offset=1, + ) + elif args.scheduler_type == "lms": + scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") + elif args.scheduler_type == "euler": + scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") + elif args.scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler( + beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear" + ) + elif args.scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler( + beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear" + ) + elif args.scheduler_type == "ddim": + scheduler = DDIMScheduler( + beta_start=beta_start, + beta_end=beta_end, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + ) + else: + raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!") + + # Convert the UNet2DConditionModel models. + if args.unet_checkpoint_path is not None: + # image UNet + image_unet_config = create_image_unet_diffusers_config(IMAGE_UNET_CONFIG) + checkpoint = torch.load(args.unet_checkpoint_path) + converted_image_unet_checkpoint = convert_vd_unet_checkpoint( + checkpoint, image_unet_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema + ) + image_unet = UNet2DConditionModel(**image_unet_config) + image_unet.load_state_dict(converted_image_unet_checkpoint) + + # text UNet + text_unet_config = create_text_unet_diffusers_config(TEXT_UNET_CONFIG) + converted_text_unet_checkpoint = convert_vd_unet_checkpoint( + checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema + ) + text_unet = UNetFlatConditionModel(**text_unet_config) + text_unet.load_state_dict(converted_text_unet_checkpoint) + + # Convert the VAE model. + if args.vae_checkpoint_path is not None: + vae_config = create_vae_diffusers_config(AUTOENCODER_CONFIG) + checkpoint = torch.load(args.vae_checkpoint_path) + converted_vae_checkpoint = convert_vd_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + image_feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14") + text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + + pipe = VersatileDiffusionPipeline( + scheduler=scheduler, + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + ) + pipe.save_pretrained(args.dump_path) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 9669382bf69f..c4052be34e1e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -75,6 +75,10 @@ StableDiffusionInpaintPipelineLegacy, StableDiffusionPipeline, StableDiffusionPipelineSafe, + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, ) else: diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 7e11bde27347..6b2bd5205bdb 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -699,3 +699,129 @@ def forward(self, x, timestep): scale, shift = torch.chunk(emb, 2) x = self.norm(x) * (1 + scale) + shift return x + + +class DualTransformer2DModel(nn.Module): + """ + Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of context dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + ): + super().__init__() + self.transformers = nn.ModuleList( + [ + Transformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + ) + for _ in range(2) + ] + ) + + # Variables that can be set by a pipeline: + + # The ratio of transformer1 to transformer2's output states to be combined during inference + self.mix_ratio = 0.5 + + # The shape of `encoder_hidden_states` is expected to be + # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` + self.condition_lengths = [77, 257] + + # Which transformer to use to encode which condition. + # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` + self.transformer_index_for_condition = [1, 0] + + def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample + tensor. + """ + input_states = hidden_states + + encoded_states = [] + tokens_start = 0 + for i in range(2): + # for each of the two transformers, pass the corresponding condition tokens + condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] + transformer_index = self.transformer_index_for_condition[i] + encoded_state = self.transformers[transformer_index](input_states, condition_state, timestep, return_dict)[ + 0 + ] + encoded_states.append(encoded_state - input_states) + tokens_start += self.condition_lengths[i] + + output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) + output_states = output_states + input_states + + if not return_dict: + return (output_states,) + + return Transformer2DModelOutput(sample=output_states) + + def _set_attention_slice(self, slice_size): + for transformer in self.transformers: + transformer._set_attention_slice(slice_size) + + def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for transformer in self.transformers: + transformer._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 6d28f07b18a6..5b337f482cc5 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -175,7 +175,7 @@ def __init__( num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) def forward( self, diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 770043f053b2..4dd15845e04f 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -15,7 +15,7 @@ import torch from torch import nn -from .attention import AttentionBlock, Transformer2DModel +from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D @@ -32,6 +32,7 @@ def get_down_block( resnet_groups=None, cross_attention_dim=None, downsample_padding=None, + dual_cross_attention=False, ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock2D": @@ -74,6 +75,7 @@ def get_down_block( downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, ) elif down_block_type == "SkipDownBlock2D": return SkipDownBlock2D( @@ -137,6 +139,7 @@ def get_up_block( attn_num_head_channels, resnet_groups=None, cross_attention_dim=None, + dual_cross_attention=False, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock2D": @@ -166,6 +169,7 @@ def get_up_block( resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, ) elif up_block_type == "AttnUpBlock2D": return AttnUpBlock2D( @@ -322,6 +326,7 @@ def __init__( attention_type="default", output_scale_factor=1.0, cross_attention_dim=1280, + dual_cross_attention=False, **kwargs, ): super().__init__() @@ -348,16 +353,28 @@ def __init__( attentions = [] for _ in range(num_layers): - attentions.append( - Transformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) ) - ) resnets.append( ResnetBlock2D( in_channels=in_channels, @@ -505,6 +522,7 @@ def __init__( output_scale_factor=1.0, downsample_padding=1, add_downsample=True, + dual_cross_attention=False, ): super().__init__() resnets = [] @@ -529,16 +547,28 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) ) - ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) @@ -1089,6 +1119,7 @@ def __init__( attention_type="default", output_scale_factor=1.0, add_upsample=True, + dual_cross_attention=False, ): super().__init__() resnets = [] @@ -1115,16 +1146,28 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) ) - ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index b09044b57b7b..4eaed803ceee 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -107,6 +107,7 @@ def __init__( norm_eps: float = 1e-5, cross_attention_dim: int = 1280, attention_head_dim: int = 8, + dual_cross_attention: bool = False, ): super().__init__() @@ -146,6 +147,7 @@ def __init__( cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim, downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, ) self.down_blocks.append(down_block) @@ -160,6 +162,7 @@ def __init__( cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim, resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, ) # count how many layers upsample the images @@ -195,6 +198,7 @@ def __init__( resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim, + dual_cross_attention=dual_cross_attention, ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -202,7 +206,7 @@ def __init__( # out self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) def set_attention_slice(self, slice_size): if slice_size is not None and self.config.attention_head_dim % slice_size != 0: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index a27cb5d2076a..9f4cef4b73e6 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -26,6 +26,12 @@ StableDiffusionPipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .versatile_diffusion import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) from .vq_diffusion import VQDiffusionPipeline if is_transformers_available() and is_onnx_available(): diff --git a/src/diffusers/pipelines/versatile_diffusion/__init__.py b/src/diffusers/pipelines/versatile_diffusion/__init__.py new file mode 100644 index 000000000000..65bc1b72007a --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/__init__.py @@ -0,0 +1,9 @@ +from ...utils import is_torch_available, is_transformers_available + + +if is_transformers_available() and is_torch_available(): + from .modeling_text_unet import UNetFlatConditionModel + from .pipeline_versatile_diffusion import VersatileDiffusionPipeline + from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline + from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline + from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py new file mode 100644 index 000000000000..c89080a59e9d --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -0,0 +1,1082 @@ +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...modeling_utils import ModelMixin +from ...models.attention import DualTransformer2DModel, Transformer2DModel +from ...models.embeddings import TimestepEmbedding, Timesteps +from ...models.unet_2d_condition import UNet2DConditionOutput +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=None, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlockFlat": + return DownBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + ) + elif down_block_type == "CrossAttnDownBlockFlat": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat") + return CrossAttnDownBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + ) + raise ValueError(f"{down_block_type} is not supported.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=None, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlockFlat": + return UpBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + ) + elif up_block_type == "CrossAttnUpBlockFlat": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat") + return CrossAttnUpBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + ) + raise ValueError(f"{up_block_type} is not supported.") + + +# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat +class UNetFlatConditionModel(ModelMixin, ConfigMixin): + r""" + UNetFlatConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a + timestep and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "DownBlockFlat", + ), + up_block_types: Tuple[str] = ( + "UpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + ), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: int = 8, + dual_cross_attention: bool = False, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = LinearMultiDim(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlockFlatCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift="default", + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + dual_cross_attention=dual_cross_attention, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.config.attention_head_dim % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.config.attention_head_dim}" + ) + if slice_size is not None and slice_size > self.config.attention_head_dim: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.config.attention_head_dim}" + ) + + for block in self.down_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_attention_slice(slice_size) + + self.mid_block.set_attention_slice(slice_size) + + for block in self.up_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_attention_slice(slice_size) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for block in self.down_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + for block in self.up_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): + (batch_size, sequence_length, hidden_size) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + + +class LinearMultiDim(nn.Linear): + def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs): + in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features) + if out_features is None: + out_features = in_features + out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features) + self.in_features_multidim = in_features + self.out_features_multidim = out_features + super().__init__(np.array(in_features).prod(), np.array(out_features).prod()) + + def forward(self, input_tensor, *args, **kwargs): + shape = input_tensor.shape + n_dim = len(self.in_features_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features) + output_tensor = super().forward(input_tensor) + output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_features_multidim) + return output_tensor + + +class ResnetBlockFlat(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + time_embedding_norm="default", + use_in_shortcut=None, + second_dim=4, + **kwargs, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + + in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels) + self.in_channels_prod = np.array(in_channels).prod() + self.channels_multidim = in_channels + + if out_channels is not None: + out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels) + out_channels_prod = np.array(out_channels).prod() + self.out_channels_multidim = out_channels + else: + out_channels_prod = self.in_channels_prod + self.out_channels_multidim = self.channels_multidim + self.time_embedding_norm = time_embedding_norm + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True) + self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0) + + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0) + + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = ( + self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut + ) + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, input_tensor, temb): + shape = input_tensor.shape + n_dim = len(self.channels_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_channels_prod, 1, 1) + input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1) + + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + output_tensor = output_tensor.view(*shape[0:-n_dim], -1) + output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim) + + return output_tensor + + +# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim +class DownBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim +class CrossAttnDownBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + if slice_size is not None and slice_size > self.attn_num_head_channels: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +class UpBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +# Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +class CrossAttnUpBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + if slice_size is not None and slice_size > self.attn_num_head_channels: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlatCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + **kwargs, + ): + super().__init__() + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + if slice_size is not None and slice_size > self.attn_num_head_channels: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py new file mode 100644 index 000000000000..1280419c343d --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py @@ -0,0 +1,462 @@ +import inspect +from typing import Callable, List, Optional, Union + +import torch + +import PIL.Image +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import logging +from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline +from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline +from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionMegaSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPFeatureExtractor + text_encoder: CLIPTextModel + image_encoder: CLIPVisionModel + image_unet: UNet2DConditionModel + text_unet: UNet2DConditionModel + vae: AutoencoderKL + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + + def __init__( + self, + tokenizer: CLIPTokenizer, + image_feature_extractor: CLIPFeatureExtractor, + text_encoder: CLIPTextModel, + image_encoder: CLIPVisionModel, + image_unet: UNet2DConditionModel, + text_unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.image_unet.config.attention_head_dim // 2 + self.image_unet.set_attention_slice(slice_size) + self.text_unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + @torch.no_grad() + def image_variation( + self, + image: Union[torch.FloatTensor, PIL.Image.Image], + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): + The image prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe(image, generator=generator).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + expected_components = inspect.signature(VersatileDiffusionImageVariationPipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + return VersatileDiffusionImageVariationPipeline(**components)( + image=image, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + + @torch.no_grad() + def text_to_image( + self, + prompt: Union[str, List[str]], + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe.text_to_image("an astronaut riding on a horse on mars", generator=generator).images[0] + >>> image.save("./astronaut.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + expected_components = inspect.signature(VersatileDiffusionTextToImagePipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + temp_pipeline = VersatileDiffusionTextToImagePipeline(**components) + output = temp_pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + # swap the attention blocks back to the original state + temp_pipeline._swap_unet_attention_blocks() + + return output + + @torch.no_grad() + def dual_guided( + self, + prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], + image: Union[str, List[str]], + text_to_image_strength: float = 0.5, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + >>> text = "a red car in the sun" + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> text_to_image_strength = 0.75 + + >>> image = pipe.dual_guided( + ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator + ... ).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images. + """ + + expected_components = inspect.signature(VersatileDiffusionDualGuidedPipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + temp_pipeline = VersatileDiffusionDualGuidedPipeline(**components) + output = temp_pipeline( + prompt=prompt, + image=image, + text_to_image_strength=text_to_image_strength, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + temp_pipeline._revert_dual_attention() + + return output diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py new file mode 100644 index 000000000000..ad4e8b0d0aa2 --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -0,0 +1,628 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint + +import PIL +from transformers import ( + CLIPFeatureExtractor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention import DualTransformer2DModel, Transformer2DModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import is_accelerate_available, logging +from .modeling_text_unet import UNetFlatConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. + tokenizer (`transformers.BertTokenizer`): + Tokenizer of class + [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPFeatureExtractor + text_encoder: CLIPTextModelWithProjection + image_encoder: CLIPVisionModelWithProjection + image_unet: UNet2DConditionModel + text_unet: UNetFlatConditionModel + vae: AutoencoderKL + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + + def __init__( + self, + tokenizer: CLIPTokenizer, + image_feature_extractor: CLIPFeatureExtractor, + text_encoder: CLIPTextModelWithProjection, + image_encoder: CLIPVisionModelWithProjection, + image_unet: UNet2DConditionModel, + text_unet: UNetFlatConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + self.register_modules( + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + + if self.text_unet is not None and ( + "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention + ): + # if loading from a universal checkpoint rather than a saved dual-guided pipeline + self._convert_to_dual_attention() + + def remove_unused_weights(self): + self.register_modules(text_unet=None) + + def _convert_to_dual_attention(self): + """ + Replace image_unet's `Transformer2DModel` blocks with `DualTransformer2DModel` that contains transformer blocks + from both `image_unet` and `text_unet` + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, Transformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + + image_transformer = self.image_unet.get_submodule(parent_name)[index] + text_transformer = self.text_unet.get_submodule(parent_name)[index] + + config = image_transformer.config + dual_transformer = DualTransformer2DModel( + num_attention_heads=config.num_attention_heads, + attention_head_dim=config.attention_head_dim, + in_channels=config.in_channels, + num_layers=config.num_layers, + dropout=config.dropout, + norm_num_groups=config.norm_num_groups, + cross_attention_dim=config.cross_attention_dim, + attention_bias=config.attention_bias, + sample_size=config.sample_size, + num_vector_embeds=config.num_vector_embeds, + activation_fn=config.activation_fn, + num_embeds_ada_norm=config.num_embeds_ada_norm, + ) + dual_transformer.transformers[0] = image_transformer + dual_transformer.transformers[1] = text_transformer + + self.image_unet.get_submodule(parent_name)[index] = dual_transformer + self.image_unet.register_to_config(dual_cross_attention=True) + + def _revert_dual_attention(self): + """ + Revert the image_unet `DualTransformer2DModel` blocks back to `Transformer2DModel` with image_unet weights Call + this function if you reuse `image_unet` in another pipeline, e.g. `VersatileDiffusionPipeline` + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, DualTransformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + self.image_unet.get_submodule(parent_name)[index] = module.transformers[0] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(False) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.image_unet.config.attention_head_dim // 2 + self.image_unet.set_attention_slice(slice_size) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"): + return self.device + for module in self.image_unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_text_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + """ + + def normalize_embeddings(encoder_output): + embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) + embeds_pooled = encoder_output.text_embeds + embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = normalize_embeddings(text_embeddings) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = normalize_embeddings(uncond_embeddings) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + """ + + def normalize_embeddings(encoder_output): + embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) + embeds = self.image_encoder.visual_projection(embeds) + embeds_pooled = embeds[:, 0:1] + embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") + pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) + image_embeddings = self.image_encoder(pixel_values) + image_embeddings = normalize_embeddings(image_embeddings) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size + uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") + pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) + uncond_embeddings = self.image_encoder(pixel_values) + uncond_embeddings = normalize_embeddings(uncond_embeddings) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and conditional embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([uncond_embeddings, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, image, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, PIL.Image.Image) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` `PIL.Image` or `list` but is {type(prompt)}") + if not isinstance(image, str) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list): + raise ValueError(f"`image` has to be of type `str` `PIL.Image` or `list` but is {type(image)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // 8, width // 8) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def set_transformer_params(self, mix_ratio: float = 0.5, condition_types: Tuple = ("text", "image")): + for name, module in self.image_unet.named_modules(): + if isinstance(module, DualTransformer2DModel): + module.mix_ratio = mix_ratio + + for i, type in enumerate(condition_types): + if type == "text": + module.condition_lengths[i] = self.text_encoder.config.max_position_embeddings + module.transformer_index_for_condition[i] = 1 # use the second (text) transformer + else: + module.condition_lengths[i] = 257 + module.transformer_index_for_condition[i] = 0 # use the first (image) transformer + + @torch.no_grad() + def __call__( + self, + prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], + image: Union[str, List[str]], + text_to_image_strength: float = 0.5, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionDualGuidedPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + >>> text = "a red car in the sun" + + >>> pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe.remove_unused_weights() + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> text_to_image_strength = 0.75 + + >>> image = pipe( + ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator + ... ).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, image, height, width, callback_steps) + + # 2. Define call parameters + prompt = [prompt] if not isinstance(prompt, list) else prompt + image = [image] if not isinstance(image, list) else image + batch_size = len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompts + text_embeddings = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) + image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance) + dual_prompt_embeddings = torch.cat([text_embeddings, image_embeddings], dim=1) + prompt_types = ("text", "image") + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + dual_prompt_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Combine the attention blocks of the image and text UNets + self.set_transformer_params(text_to_image_strength, prompt_types) + + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py new file mode 100644 index 000000000000..652b7b735af4 --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -0,0 +1,462 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +import torch.utils.checkpoint + +import PIL +from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import is_accelerate_available, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. + tokenizer (`transformers.BertTokenizer`): + Tokenizer of class + [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + image_feature_extractor: CLIPFeatureExtractor + image_encoder: CLIPVisionModelWithProjection + image_unet: UNet2DConditionModel + vae: AutoencoderKL + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + + def __init__( + self, + image_feature_extractor: CLIPFeatureExtractor, + image_encoder: CLIPVisionModelWithProjection, + image_unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + self.register_modules( + image_feature_extractor=image_feature_extractor, + image_encoder=image_encoder, + image_unet=image_unet, + vae=vae, + scheduler=scheduler, + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(False) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.image_unet.config.attention_head_dim // 2 + self.image_unet.set_attention_slice(slice_size) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"): + return self.device + for module in self.image_unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + + def normalize_embeddings(encoder_output): + embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) + embeds = self.image_encoder.visual_projection(embeds) + embeds_pooled = embeds[:, 0:1] + embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") + pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) + image_embeddings = self.image_encoder(pixel_values) + image_embeddings = normalize_embeddings(image_embeddings) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_images: List[str] + if negative_prompt is None: + uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, PIL.Image.Image): + uncond_images = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_images = negative_prompt + + uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") + pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) + uncond_embeddings = self.image_encoder(pixel_values) + uncond_embeddings = normalize_embeddings(uncond_embeddings) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and conditional embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([uncond_embeddings, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, image, height, width, callback_steps): + if not isinstance(image, PIL.Image.Image) and not isinstance(image, torch.Tensor): + raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // 8, width // 8) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor], + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): + The image prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionImageVariationPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + + >>> pipe = VersatileDiffusionImageVariationPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe(image, generator=generator).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + image_embeddings = self._encode_prompt( + image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py new file mode 100644 index 000000000000..d07d734a642f --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -0,0 +1,514 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import torch +import torch.utils.checkpoint + +from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention import Transformer2DModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import is_accelerate_available, logging +from .modeling_text_unet import UNetFlatConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. + tokenizer (`transformers.BertTokenizer`): + Tokenizer of class + [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPFeatureExtractor + text_encoder: CLIPTextModelWithProjection + image_unet: UNet2DConditionModel + text_unet: UNetFlatConditionModel + vae: AutoencoderKL + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + image_unet: UNet2DConditionModel, + text_unet: UNetFlatConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + + if self.text_unet is not None: + self._swap_unet_attention_blocks() + + def _swap_unet_attention_blocks(self): + """ + Swap the `Transformer2DModel` blocks between the image and text UNets + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, Transformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = ( + self.text_unet.get_submodule(parent_name)[index], + self.image_unet.get_submodule(parent_name)[index], + ) + + def remove_unused_weights(self): + self.register_modules(text_unet=None) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(False) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.image_unet.config.attention_head_dim // 2 + self.image_unet.set_attention_slice(slice_size) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"): + return self.device + for module in self.image_unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + + def normalize_embeddings(encoder_output): + embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) + embeds_pooled = encoder_output.text_embeds + embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = normalize_embeddings(text_embeddings) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = normalize_embeddings(uncond_embeddings) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // 8, width // 8) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionTextToImagePipeline + >>> import torch + + >>> pipe = VersatileDiffusionTextToImagePipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe.remove_unused_weights() + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe("an astronaut riding on a horse on mars", generator=generator).images[0] + >>> image.save("./astronaut.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index c184b6295de7..d255c174c743 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -154,6 +154,66 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionImageVariationPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionTextToImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class VQDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/versatile_diffusion/__init__.py b/tests/pipelines/versatile_diffusion/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_dual_guided.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_dual_guided.py new file mode 100644 index 000000000000..9fb6ca522f5d --- /dev/null +++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_dual_guided.py @@ -0,0 +1,112 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers import VersatileDiffusionDualGuidedPipeline +from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class VersatileDiffusionDualGuidedPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pass + + +@slow +@require_torch_gpu +class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_remove_unused_weights_save_load(self): + pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("shi-labs/versatile-diffusion") + # remove text_unet + pipe.remove_unused_weights() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + second_prompt = load_image( + "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg" + ) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe( + prompt="first prompt", + image=second_prompt, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=2, + output_type="numpy", + ).images + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(tmpdirname) + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = generator.manual_seed(0) + new_image = pipe( + prompt="first prompt", + image=second_prompt, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=2, + output_type="numpy", + ).images + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass" + + def test_inference_dual_guided(self): + pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("shi-labs/versatile-diffusion") + pipe.remove_unused_weights() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + first_prompt = "cyberpunk 2077" + second_prompt = load_image( + "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg" + ) + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe( + prompt=first_prompt, + image=second_prompt, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=50, + output_type="numpy", + ).images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.014, 0.0112, 0.0136, 0.0145, 0.0107, 0.0113, 0.0272, 0.0215, 0.0216]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_variation.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_variation.py new file mode 100644 index 000000000000..4eddc271db52 --- /dev/null +++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_variation.py @@ -0,0 +1,58 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import VersatileDiffusionImageVariationPipeline +from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class VersatileDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pass + + +@slow +@require_torch_gpu +class VersatileDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase): + def test_inference_image_variations(self): + pipe = VersatileDiffusionImageVariationPipeline.from_pretrained("shi-labs/versatile-diffusion") + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + image_prompt = load_image( + "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg" + ) + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe( + image=image_prompt, + generator=generator, + guidance_scale=7.5, + num_inference_steps=50, + output_type="numpy", + ).images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0113, 0.2241, 0.4024, 0.0839, 0.0871, 0.2725, 0.2581, 0.0, 0.1096]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py new file mode 100644 index 000000000000..1209abf6a8db --- /dev/null +++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py @@ -0,0 +1,129 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers import VersatileDiffusionPipeline +from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class VersatileDiffusionMegaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pass + + +@slow +@require_torch_gpu +class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_from_pretrained_save_pretrained(self): + pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt_image = load_image( + "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg" + ) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe.dual_guided( + prompt="first prompt", + image=prompt_image, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=2, + output_type="numpy", + ).images + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = VersatileDiffusionPipeline.from_pretrained(tmpdirname, torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = generator.manual_seed(0) + new_image = pipe.dual_guided( + prompt="first prompt", + image=prompt_image, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=2, + output_type="numpy", + ).images + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass" + + def test_inference_dual_guided_then_text_to_image(self): + pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "cyberpunk 2077" + init_image = load_image( + "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg" + ) + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe.dual_guided( + prompt=prompt, + image=init_image, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=50, + output_type="numpy", + ).images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.014, 0.0112, 0.0136, 0.0145, 0.0107, 0.0113, 0.0272, 0.0215, 0.0216]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + prompt = "A painting of a squirrel eating a burger " + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe.text_to_image( + prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy" + ).images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0408, 0.0181, 0.0, 0.0388, 0.0046, 0.0461, 0.0411, 0.0, 0.0222]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16) + image = pipe.image_variation(init_image, generator=generator, output_type="numpy").images[0] + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0657, 0.0529, 0.0455, 0.0802, 0.0570, 0.0179, 0.0267, 0.0483, 0.0769]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_text_to_image.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_text_to_image.py new file mode 100644 index 000000000000..027819efee9f --- /dev/null +++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_text_to_image.py @@ -0,0 +1,86 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers import VersatileDiffusionTextToImagePipeline +from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class VersatileDiffusionTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pass + + +@slow +@require_torch_gpu +class VersatileDiffusionTextToImagePipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_remove_unused_weights_save_load(self): + pipe = VersatileDiffusionTextToImagePipeline.from_pretrained("shi-labs/versatile-diffusion") + # remove text_unet + pipe.remove_unused_weights() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger " + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe( + prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=2, output_type="numpy" + ).images + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = VersatileDiffusionTextToImagePipeline.from_pretrained(tmpdirname) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = generator.manual_seed(0) + new_image = pipe( + prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=2, output_type="numpy" + ).images + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass" + + def test_inference_text2img(self): + pipe = VersatileDiffusionTextToImagePipeline.from_pretrained("shi-labs/versatile-diffusion") + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger " + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe( + prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy" + ).images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0408, 0.0181, 0.0, 0.0388, 0.0046, 0.0461, 0.0411, 0.0, 0.0222]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2