diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f7bce58d913b..d87c3e238de3 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -192,6 +192,8 @@ title: Stable unCLIP - local: api/pipelines/stochastic_karras_ve title: Stochastic Karras VE + - local: api/pipelines/text_to_video + title: Text-to-Video - local: api/pipelines/unclip title: UnCLIP - local: api/pipelines/latent_diffusion_uncond diff --git a/docs/source/en/api/models.mdx b/docs/source/en/api/models.mdx index dc425e98628c..572f8873ba12 100644 --- a/docs/source/en/api/models.mdx +++ b/docs/source/en/api/models.mdx @@ -37,6 +37,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## UNet2DConditionModel [[autodoc]] UNet2DConditionModel +## UNet3DConditionOutput +[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput + +## UNet3DConditionModel +[[autodoc]] UNet3DConditionModel + ## DecoderOutput [[autodoc]] models.vae.DecoderOutput @@ -58,6 +64,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## Transformer2DModelOutput [[autodoc]] models.transformer_2d.Transformer2DModelOutput +## TransformerTemporalModel +[[autodoc]] models.transformer_temporal.TransformerTemporalModel + +## Transformer2DModelOutput +[[autodoc]] models.transformer_temporal.TransformerTemporalModelOutput + ## PriorTransformer [[autodoc]] models.prior_transformer.PriorTransformer diff --git a/docs/source/en/api/pipelines/overview.mdx b/docs/source/en/api/pipelines/overview.mdx index 6d0a9a1159b2..3bf29888ae54 100644 --- a/docs/source/en/api/pipelines/overview.mdx +++ b/docs/source/en/api/pipelines/overview.mdx @@ -77,6 +77,7 @@ available a colab notebook to directly try them out. | [stable_unclip](./stable_unclip) | **Stable unCLIP** | Text-to-Image Generation | | [stable_unclip](./stable_unclip) | **Stable unCLIP** | Image-to-Image Text-Guided Generation | | [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [text_to_video_sd](./api/pipelines/text_to_video) | [Modelscope's Text-to-video-synthesis Model in Open Domain](https://modelscope.cn/models/damo/text-to-video-synthesis/summary) | Text-to-Video Generation | | [unclip](./unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125) | Text-to-Image Generation | | [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation | | [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation | diff --git a/docs/source/en/api/pipelines/text_to_video.mdx b/docs/source/en/api/pipelines/text_to_video.mdx new file mode 100644 index 000000000000..f1fe794e1537 --- /dev/null +++ b/docs/source/en/api/pipelines/text_to_video.mdx @@ -0,0 +1,122 @@ + + +# Text-to-video synthesis + +Text-to-video synthesis from [ModelScope](https://modelscope.cn/) can be considered the same as Stable Diffusion structure-wise but it is extended to videos instead of static images. More specifically, this system allows us to generate videos from a natural language text prompt. + +From the [model summary](https://huggingface.co/damo-vilab/modelscope-damo-text-to-video-synthesis): + +*This model is based on a multi-stage text-to-video generation diffusion model, which inputs a description text and returns a video that matches the text description. Only English input is supported.* + +Resources: + +* [Website](https://modelscope.cn/models/damo/text-to-video-synthesis/summary) +* [GitHub repository](https://github.com/modelscope/modelscope/) +* [Spaces] (TODO) + +## Available Pipelines: + +| Pipeline | Tasks | Demo +|---|---|:---:| +| [DiffusionPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py) | *Text-to-Video Generation* | [Spaces] (TODO) + +## Usage example + +Let's start by generating a short video with the default length of 16 frames (2s at 8 fps): + +```python +import torch +from diffusers import DiffusionPipeline +from diffusers.utils import export_to_video + +pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16") +pipe = pipe.to("cuda") + +prompt = "Spiderman is surfing" +video_frames = pipe(prompt).frames +video_path = export_to_video(video_frames) +video_path +``` + +Diffusers supports different optimization techniques to improve the latency +and memory footprint of a pipeline. Since videos are often more memory-heavy than images, +we can enable CPU offloading and VAE slicing to keep the memory footprint at bay. + +Let's generate a video of 8 seconds (64 frames) on the same GPU using CPU offloading and VAE slicing: + +```python +import torch +from diffusers import DiffusionPipeline +from diffusers.utils import export_to_video + +pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16") +pipe.enable_model_cpu_offload() + +# memory optimization +pipe.enable_vae_slicing() + +prompt = "Darth Vader surfing a wave" +video_frames = pipe(prompt, num_frames=64).frames +video_path = export_to_video(video_frames) +video_path +``` + +It just takes **7 GBs of GPU memory** to generate the 64 video frames using PyTorch 2.0, "fp16" precision and the techniques mentioned above. + +We can also use a different scheduler easily, using the same method we'd use for Stable Diffusion: + +```python +import torch +from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler +from diffusers.utils import export_to_video + +pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16") +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) +pipe.enable_model_cpu_offload() + +prompt = "Spiderman is surfing" +video_frames = pipe(prompt, num_inference_steps=25).frames +video_path = export_to_video(video_frames) +video_path +``` + +Here are some sample outputs: + + + + + + +
+ An astronaut riding a horse. +
+ An astronaut riding a horse. +
+ Darth vader surfing in waves. +
+ Darth vader surfing in waves. +
+ +## Available checkpoints + +* [damo-vilab/text-to-video-ms-1.7b](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/) +* [damo-vilab/text-to-video-ms-1.7b-legacy](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b-legacy) + +## DiffusionPipeline +[[autodoc]] DiffusionPipeline + - all + - __call__ diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 59c4d595cc8b..2ccabb1b32ee 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -84,8 +84,9 @@ The library has three main components: | [stable_unclip](./stable_unclip) | Stable unCLIP | Text-to-Image Generation | | [stable_unclip](./stable_unclip) | Stable unCLIP | Image-to-Image Text-Guided Generation | | [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [text_to_video_sd](./api/pipelines/text_to_video) | [Modelscope's Text-to-video-synthesis Model in Open Domain](https://modelscope.cn/models/damo/text-to-video-synthesis/summary) | Text-to-Video Generation | | [unclip](./api/pipelines/unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125)(implementation by [kakaobrain](https://github.com/kakaobrain/karlo)) | Text-to-Image Generation | | [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation | | [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation | | [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation | -| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | \ No newline at end of file +| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | diff --git a/examples/community/stable_diffusion_controlnet_img2img.py b/examples/community/stable_diffusion_controlnet_img2img.py index 51533a92d84a..ec23564ae3cb 100644 --- a/examples/community/stable_diffusion_controlnet_img2img.py +++ b/examples/community/stable_diffusion_controlnet_img2img.py @@ -216,7 +216,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py index 02e71fb97ed1..b7c8a2a7a7f0 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint.py +++ b/examples/community/stable_diffusion_controlnet_inpaint.py @@ -314,7 +314,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py index a7afe26fa91c..f435a3274f45 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py +++ b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py @@ -314,7 +314,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/scripts/convert_ms_text_to_video_to_diffusers.py b/scripts/convert_ms_text_to_video_to_diffusers.py new file mode 100644 index 000000000000..3102c7eede9b --- /dev/null +++ b/scripts/convert_ms_text_to_video_to_diffusers.py @@ -0,0 +1,428 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Conversion script for the LDM checkpoints. """ + +import argparse + +import torch + +from diffusers import UNet3DConditionModel + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + weight = old_checkpoint[path["old"]] + names = ["proj_attn.weight"] + names_2 = ["proj_out.weight", "proj_in.weight"] + if any(k in new_path for k in names): + checkpoint[new_path] = weight[:, :, 0] + elif any(k in new_path for k in names_2) and len(weight.shape) > 2 and ".attentions." not in new_path: + checkpoint[new_path] = weight[:, :, 0] + else: + checkpoint[new_path] = weight + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_temp_conv_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + mapping.append({"old": old_item, "new": old_item}) + + return mapping + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + if "temopral_conv" not in old_item: + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + first_temp_attention = [v for v in unet_state_dict if v.startswith("input_blocks.0.1")] + paths = renew_attention_paths(first_temp_attention) + meta_path = {"old": "input_blocks.0.1", "new": "transformer_in"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + temp_attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.2" in key] + + if f"input_blocks.{i}.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + temporal_convs = [key for key in resnets if "temopral_conv" in key] + paths = renew_temp_conv_paths(temporal_convs) + meta_path = { + "old": f"input_blocks.{i}.0.temopral_conv", + "new": f"down_blocks.{block_id}.temp_convs.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(temp_attentions): + paths = renew_attention_paths(temp_attentions) + meta_path = { + "old": f"input_blocks.{i}.2", + "new": f"down_blocks.{block_id}.temp_attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + temporal_convs_0 = [key for key in resnet_0 if "temopral_conv" in key] + attentions = middle_blocks[1] + temp_attentions = middle_blocks[2] + resnet_1 = middle_blocks[3] + temporal_convs_1 = [key for key in resnet_1 if "temopral_conv" in key] + + resnet_0_paths = renew_resnet_paths(resnet_0) + meta_path = {"old": "middle_block.0", "new": "mid_block.resnets.0"} + assign_to_checkpoint( + resnet_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path] + ) + + temp_conv_0_paths = renew_temp_conv_paths(temporal_convs_0) + meta_path = {"old": "middle_block.0.temopral_conv", "new": "mid_block.temp_convs.0"} + assign_to_checkpoint( + temp_conv_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path] + ) + + resnet_1_paths = renew_resnet_paths(resnet_1) + meta_path = {"old": "middle_block.3", "new": "mid_block.resnets.1"} + assign_to_checkpoint( + resnet_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path] + ) + + temp_conv_1_paths = renew_temp_conv_paths(temporal_convs_1) + meta_path = {"old": "middle_block.3.temopral_conv", "new": "mid_block.temp_convs.1"} + assign_to_checkpoint( + temp_conv_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path] + ) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + temp_attentions_paths = renew_attention_paths(temp_attentions) + meta_path = {"old": "middle_block.2", "new": "mid_block.temp_attentions.0"} + assign_to_checkpoint( + temp_attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + temp_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + temporal_convs = [key for key in resnets if "temopral_conv" in key] + paths = renew_temp_conv_paths(temporal_convs) + meta_path = { + "old": f"output_blocks.{i}.0.temopral_conv", + "new": f"up_blocks.{block_id}.temp_convs.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(temp_attentions): + paths = renew_attention_paths(temp_attentions) + meta_path = { + "old": f"output_blocks.{i}.2", + "new": f"up_blocks.{block_id}.temp_attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + new_checkpoint[new_path] = unet_state_dict[old_path] + + temopral_conv_paths = [l for l in output_block_layers if "temopral_conv" in l] + for path in temopral_conv_paths: + pruned_path = path.split("temopral_conv.")[-1] + old_path = ".".join(["output_blocks", str(i), str(block_id), "temopral_conv", pruned_path]) + new_path = ".".join(["up_blocks", str(block_id), "temp_convs", str(layer_in_block_id), pruned_path]) + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + args = parser.parse_args() + + unet_checkpoint = torch.load(args.checkpoint_path, map_location="cpu") + unet = UNet3DConditionModel() + + converted_ckpt = convert_ldm_unet_checkpoint(unet_checkpoint, unet.config) + + diff_0 = set(unet.state_dict().keys()) - set(converted_ckpt.keys()) + diff_1 = set(converted_ckpt.keys()) - set(unet.state_dict().keys()) + + assert len(diff_0) == len(diff_1) == 0, "Converted weights don't match" + + # load state_dict + unet.load_state_dict(converted_ckpt) + + unet.save_pretrained(args.dump_path) + + # -- finish converting the unet -- diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f480b4100907..a1e736671be7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -41,6 +41,7 @@ UNet1DModel, UNet2DConditionModel, UNet2DModel, + UNet3DConditionModel, VQModel, ) from .optimization import ( @@ -130,6 +131,7 @@ StableDiffusionUpscalePipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, + TextToVideoSDPipeline, UnCLIPImageVariationPipeline, UnCLIPPipeline, VersatileDiffusionDualGuidedPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e0b2cddd4bf9..752aeb409f57 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -25,6 +25,7 @@ from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel + from .unet_3d_condition import UNet3DConditionModel from .vq_model import VQModel if is_flax_available(): diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 5c7e54e7cd32..f271e00f8639 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -184,6 +184,10 @@ class BasicTransformerBlock(nn.Module): attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (: obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. @@ -202,6 +206,7 @@ def __init__( num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, + double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = True, norm_type: str = "layer_norm", @@ -233,10 +238,10 @@ def __init__( self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) # 2. Cross-Attn - if cross_attention_dim is not None: + if cross_attention_dim is not None or double_self_attention: self.attn2 = Attention( query_dim=dim, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, @@ -253,7 +258,7 @@ def __init__( else: self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - if cross_attention_dim is not None: + if cross_attention_dim is not None or double_self_attention: # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during # the second cross attention block. diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 9c0161065e4c..8f65c2357cac 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -207,6 +207,7 @@ def blend_h(self, a, b, blend_extent): def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: r"""Encode a batch of images using a tiled encoder. + Args: When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is: @@ -253,6 +254,7 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Autoen def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: r"""Decode a batch of images using a tiled decoder. + Args: When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is: diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index d159115d7ee3..98f8f19c896a 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -1,3 +1,18 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from functools import partial from typing import Optional @@ -764,3 +779,61 @@ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 return out.view(-1, channel, out_h, out_w) + + +class TemporalConvLayer(nn.Module): + """ + Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: + https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 + """ + + def __init__(self, in_dim, out_dim=None, dropout=0.0): + super().__init__() + out_dim = out_dim or in_dim + self.in_dim = in_dim + self.out_dim = out_dim + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)) + ) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + self.conv3 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + self.conv4 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + + def forward(self, hidden_states, num_frames=1): + hidden_states = ( + hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4) + ) + + identity = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.conv3(hidden_states) + hidden_states = self.conv4(hidden_states) + + hidden_states = identity + hidden_states + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape( + (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:] + ) + return hidden_states diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py new file mode 100644 index 000000000000..ece88b8db2d5 --- /dev/null +++ b/src/diffusers/models/transformer_temporal.py @@ -0,0 +1,176 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .attention import BasicTransformerBlock +from .modeling_utils import ModelMixin + + +@dataclass +class TransformerTemporalModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`) + Hidden states conditioned on `encoder_hidden_states` input. + """ + + sample: torch.FloatTensor + + +class TransformerTemporalModel(ModelMixin, ConfigMixin): + """ + Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + double_self_attention (`bool`, *optional*): + Configure if each TransformerBlock should contain two self-attention layers + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + ) + for d in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + num_frames=1, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`: + [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, channel, num_frames) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py new file mode 100644 index 000000000000..9f8ee2a22aab --- /dev/null +++ b/src/diffusers/models/unet_3d_blocks.py @@ -0,0 +1,670 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + +from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D +from .transformer_2d import Transformer2DModel +from .transformer_temporal import TransformerTemporalModel + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=True, + upcast_attention=False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + ) + ] + attentions = [] + temp_attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + in_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + in_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + temp_attentions = [] + temp_convs = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + ) + ) + attentions.append( + Transformer2DModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, num_frames=1): + output_states = () + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + temp_attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + ) + ) + attentions.append( + Transformer2DModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1): + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py new file mode 100644 index 000000000000..de762acabebf --- /dev/null +++ b/src/diffusers/models/unet_3d_condition.py @@ -0,0 +1,492 @@ +# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# Copyright 2023 The ModelScope Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging +from .embeddings import TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .transformer_temporal import TransformerTemporalModel +from .unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + r""" + UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, it will skip the normalization and activation layers in post-processing + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1024, + attention_head_dim: Union[int, Tuple[int]] = 64, + ): + super().__init__() + + self.sample_size = sample_size + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=attention_head_dim, + in_channels=block_out_channels[0], + num_layers=1, + ) + + # class embedding + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=False, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=False, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, num_frames, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet3DConditionOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + [`~models.unet_2d_condition.UNet3DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + + sample = self.transformer_in(sample, num_frames=num_frames).sample + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 5b6c729f80be..87d1a6997e59 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -65,6 +65,7 @@ StableUnCLIPPipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .text_to_video_synthesis import TextToVideoSDPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .versatile_diffusion import ( VersatileDiffusionDualGuidedPipeline, diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index b94a2ec05649..1ae82beb54a4 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -234,7 +234,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 05138c86f246..b71217a4b3ec 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -244,7 +244,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index e977071b9c6c..76423867add1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -258,7 +258,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 5294fa4cfa06..81b2cfa9bc3e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -237,7 +237,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index fd82281005ad..aeb70b1b2234 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -274,7 +274,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 8b3a7944def1..835c88e19448 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -249,7 +249,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index b645ba667f77..cee7ace239db 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -293,7 +293,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index a770fb18aaa0..cb953a7803b2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -237,7 +237,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 953df11aa4f7..06ab580d492f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -432,7 +432,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index f3db54caa342..2d40390b41d1 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -158,7 +158,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 7de12bd291fb..9c928129d0b9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -394,7 +394,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py new file mode 100644 index 000000000000..c2437857a23a --- /dev/null +++ b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import torch + +from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available + + +@dataclass +class TextToVideoSDPipelineOutput(BaseOutput): + """ + Output class for text to video pipelines. + + Args: + frames (`List[np.ndarray]` or `torch.FloatTensor`) + List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as + a `torch` tensor. NumPy array present the denoised images of the diffusion pipeline. The length of the list + denotes the video length i.e., the number of frames. + """ + + frames: Union[List[np.ndarray], torch.FloatTensor] + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + from .pipeline_text_to_video_synth import TextToVideoSDPipeline # noqa: F401 diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py new file mode 100644 index 000000000000..453809ef6df7 --- /dev/null +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -0,0 +1,668 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKL, UNet3DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import TextToVideoSDPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import TextToVideoSDPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = TextToVideoSDPipeline.from_pretrained( + ... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16" + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "Spiderman is surfing" + >>> video_frames = pipe(prompt).frames + >>> video_path = export_to_video(video_frames) + >>> video_path + ``` +""" + + +def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: + # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + # reshape to ncfhw + mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) + # unnormalize back to [0,1] + video = video.mul_(std).add_(mean) + video.clamp_(0, 1) + # prepare the final outputs + i, c, f, h, w = video.shape + images = video.permute(2, 3, 0, 4, 1).reshape( + f, h, i * w, c + ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) + images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) + images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c + return images + + +class TextToVideoSDPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Same as Stable Diffusion 2. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded + to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a + submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = ( + image[None, :] + .reshape( + ( + batch_size, + num_frames, + -1, + ) + + image.shape[2:] + ) + .permute(0, 2, 1, 3, 4) + ) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: int = 16, + num_inference_steps: int = 50, + guidance_scale: float = 9.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, + usually at the expense of lower video quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generate video. Choose between `torch.FloatTensor` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_images_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # reshape latents + bsz, channel, frames, width, height = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # reshape latents back + latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + video_tensor = self.decode_latents(latents) + + if output_type == "pt": + video = video_tensor + else: + video = tensor2vid(video_tensor) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (video,) + + return TextToVideoSDPipelineOutput(frames=video) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 196b3b0279d0..d803b053be71 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -92,6 +92,8 @@ torch_device, ) +from .testing_utils import export_to_video + logger = get_logger(__name__) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index c731a1f1ddf3..700a3080fa11 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -122,6 +122,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class UNet3DConditionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class VQModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 1b0f812ad16c..5a28ce8cb04e 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -347,6 +347,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class TextToVideoSDPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class UnCLIPImageVariationPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index b3c6d1824369..3c09cb24f965 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -169,6 +169,14 @@ if _onnx_available: logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") +# (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed. +# _opencv_available = importlib.util.find_spec("opencv-python") is not None +try: + _opencv_version = importlib_metadata.version("opencv-python") + _opencv_available = True + logger.debug(f"Successfully imported cv2 version {_opencv_version}") +except importlib_metadata.PackageNotFoundError: + _opencv_available = False _scipy_available = importlib.util.find_spec("scipy") is not None try: @@ -272,6 +280,10 @@ def is_onnx_available(): return _onnx_available +def is_opencv_available(): + return _opencv_available + + def is_scipy_available(): return _scipy_available @@ -332,6 +344,12 @@ def is_compel_available(): install onnxruntime` """ +# docstyle-ignore +OPENCV_IMPORT_ERROR = """ +{0} requires the OpenCV library but it was not found in your environment. You can install it with pip: `pip +install opencv-python` +""" + # docstyle-ignore SCIPY_IMPORT_ERROR = """ {0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install @@ -391,6 +409,7 @@ def is_compel_available(): ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), + ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)), ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index cea2869b3193..7a3b8029f828 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -3,12 +3,13 @@ import os import random import re +import tempfile import unittest import urllib.parse from distutils.util import strtobool from io import BytesIO, StringIO from pathlib import Path -from typing import Optional, Union +from typing import List, Optional, Union import numpy as np import PIL.Image @@ -16,7 +17,14 @@ import requests from packaging import version -from .import_utils import is_compel_available, is_flax_available, is_onnx_available, is_torch_available +from .import_utils import ( + BACKENDS_MAPPING, + is_compel_available, + is_flax_available, + is_onnx_available, + is_opencv_available, + is_torch_available, +) from .logging import get_logger @@ -253,6 +261,23 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: return image +def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str: + if is_opencv_available(): + import cv2 + else: + raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video")) + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + h, w, c = video_frames[0].shape + video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h)) + for i in range(len(video_frames)): + img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) + video_writer.write(img) + return output_video_path + + def load_hf_numpy(path) -> np.ndarray: if not path.startswith("http://") or path.startswith("https://"): path = os.path.join( diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py new file mode 100644 index 000000000000..a92b8edd5378 --- /dev/null +++ b/tests/models/test_models_unet_3d_condition.py @@ -0,0 +1,242 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers.models import ModelMixin, UNet3DConditionModel +from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.utils import ( + floats_tensor, + logging, + torch_device, +) +from diffusers.utils.import_utils import is_xformers_available + +from ..test_modeling_common import ModelTesterMixin + + +logger = logging.get_logger(__name__) +torch.backends.cuda.matmul.allow_tf32 = False + + +def create_lora_layers(model): + lora_attn_procs = {} + for name in model.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = model.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + lora_attn_procs[name] = lora_attn_procs[name].to(model.device) + + # add 1 to weights to mock trained weights + with torch.no_grad(): + lora_attn_procs[name].to_q_lora.up.weight += 1 + lora_attn_procs[name].to_k_lora.up.weight += 1 + lora_attn_procs[name].to_v_lora.up.weight += 1 + lora_attn_procs[name].to_out_lora.up.weight += 1 + + return lora_attn_procs + + +class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet3DConditionModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 4 + num_frames = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device) + + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + + @property + def input_shape(self): + return (4, 4, 32, 32) + + @property + def output_shape(self): + return (4, 4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64, 64, 64), + "down_block_types": ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + "up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + "cross_attention_dim": 32, + "attention_head_dim": 4, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 2, + "sample_size": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_enable_works(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + + model.enable_xformers_memory_efficient_attention() + + assert ( + model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ + == "XFormersAttnProcessor" + ), "xformers is not enabled" + + # Overriding because `block_out_channels` needs to be different for this model. + def test_forward_with_norm_groups(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["norm_num_groups"] = 32 + init_dict["block_out_channels"] = (32, 64, 64, 64) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + # Overriding since the UNet3D outputs a different structure. + def test_determinism(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + # Warmup pass when using mps (see #372) + if torch_device == "mps" and isinstance(model, ModelMixin): + model(**self.dummy_input) + + first = model(**inputs_dict) + if isinstance(first, dict): + first = first.sample + + second = model(**inputs_dict) + if isinstance(second, dict): + second = second.sample + + out_1 = first.cpu().numpy() + out_2 = second.cpu().numpy() + out_1 = out_1[~np.isnan(out_1)] + out_2 = out_2[~np.isnan(out_2)] + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5) + + def test_model_attention_slicing(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + model.set_attention_slice("auto") + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + model.set_attention_slice("max") + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + model.set_attention_slice(2) + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + # (`attn_processors`) needs to be implemented in this model for this test. + # def test_lora_processors(self): + + # (`attn_processors`) needs to be implemented in this model for this test. + # def test_lora_save_load(self): + + # (`attn_processors`) needs to be implemented for this test in the model. + # def test_lora_save_load_safetensors(self): + + # (`attn_processors`) needs to be implemented for this test in the model. + # def test_lora_save_safetensors_load_torch(self): + + # (`attn_processors`) needs to be implemented for this test. + # def test_lora_save_torch_force_load_safetensors_error(self): + + # (`attn_processors`) needs to be added for this test. + # def test_lora_on_off(self): + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_lora_xformers_on_off(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 4 + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + lora_attn_procs = create_lora_layers(model) + model.set_attn_processor(lora_attn_procs) + + # default + with torch.no_grad(): + sample = model(**inputs_dict).sample + + model.enable_xformers_memory_efficient_attention() + on_sample = model(**inputs_dict).sample + + model.disable_xformers_memory_efficient_attention() + off_sample = model(**inputs_dict).sample + + assert (sample - on_sample).abs().max() < 1e-4 + assert (sample - off_sample).abs().max() < 1e-4 + + +# (todo: sayakpaul) implement SLOW tests. diff --git a/tests/pipelines/text_to_video/__init__.py b/tests/pipelines/text_to_video/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/text_to_video/test_text_to_video.py b/tests/pipelines/text_to_video/test_text_to_video.py new file mode 100644 index 000000000000..eb43a360653a --- /dev/null +++ b/tests/pipelines/text_to_video/test_text_to_video.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DPMSolverMultistepScheduler, + TextToVideoSDPipeline, + UNet3DConditionModel, +) +from diffusers.utils import load_numpy, skip_mps, slow + +from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = TextToVideoSDPipeline + params = TEXT_TO_IMAGE_PARAMS + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + # No `output_type`. + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback", + "callback_steps", + ] + ) + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet3DConditionModel( + block_out_channels=(32, 64, 64, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D"), + up_block_types=("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + cross_attention_dim=32, + attention_head_dim=4, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=512, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "pt", + } + return inputs + + def test_text_to_video_default_case(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = TextToVideoSDPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["output_type"] = "np" + frames = sd_pipe(**inputs).frames + image_slice = frames[0][-3:, -3:, -1] + + assert frames[0].shape == (64, 64, 3) + expected_slice = np.array([166, 184, 167, 118, 102, 123, 108, 93, 114]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_attention_slicing_forward_pass(self): + self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False) + + # (todo): sayakpaul + @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.") + def test_inference_batch_consistent(self): + pass + + # (todo): sayakpaul + @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.") + def test_inference_batch_single_identical(self): + pass + + @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.") + def test_num_images_per_prompt(self): + pass + + @skip_mps + def test_progress_bar(self): + return super().test_progress_bar() + + +@slow +class TextToVideoSDPipelineSlowTests(unittest.TestCase): + def test_full_model(self): + expected_video = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video.npy" + ) + + pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b") + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe = pipe.to("cuda") + + prompt = "Spiderman is surfing" + generator = torch.Generator(device="cpu").manual_seed(0) + + video_frames = pipe(prompt, generator=generator, num_inference_steps=25, output_type="pt").frames + video = video_frames.cpu().numpy() + + assert np.abs(expected_video - video).mean() < 5e-2 + + def test_two_step_model(self): + expected_video = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video_2step.npy" + ) + + pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b") + pipe = pipe.to("cuda") + + prompt = "Spiderman is surfing" + generator = torch.Generator(device="cpu").manual_seed(0) + + video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames + video = video_frames.cpu().numpy() + + assert np.abs(expected_video - video).mean() < 5e-2 diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index 1ab6baeb81a3..ac2abd716e42 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -20,6 +20,13 @@ torch.backends.cuda.matmul.allow_tf32 = False +def to_np(tensor): + if isinstance(tensor, torch.Tensor): + tensor = tensor.detach().cpu().numpy() + + return tensor + + @require_torch class PipelineTesterMixin: """ @@ -130,7 +137,7 @@ def test_save_load_local(self): inputs = self.get_dummy_inputs(torch_device) output_loaded = pipe_loaded(**inputs)[0] - max_diff = np.abs(output - output_loaded).max() + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() self.assertLess(max_diff, 1e-4) def test_pipeline_call_signature(self): @@ -327,7 +334,7 @@ def test_dict_tuple_outputs_equivalent(self): output = pipe(**self.get_dummy_inputs(torch_device))[0] output_tuple = pipe(**self.get_dummy_inputs(torch_device), return_dict=False)[0] - max_diff = np.abs(output - output_tuple).max() + max_diff = np.abs(to_np(output) - to_np(output_tuple)).max() self.assertLess(max_diff, 1e-4) def test_components_function(self): @@ -351,7 +358,7 @@ def test_float16_inference(self): output = pipe(**self.get_dummy_inputs(torch_device))[0] output_fp16 = pipe_fp16(**self.get_dummy_inputs(torch_device))[0] - max_diff = np.abs(output - output_fp16).max() + max_diff = np.abs(to_np(output) - to_np(output_fp16)).max() self.assertLess(max_diff, 1e-2, "The outputs of the fp16 and fp32 pipelines are too different.") @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") @@ -383,7 +390,7 @@ def test_save_load_float16(self): inputs = self.get_dummy_inputs(torch_device) output_loaded = pipe_loaded(**inputs)[0] - max_diff = np.abs(output - output_loaded).max() + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() self.assertLess(max_diff, 1e-2, "The output of the fp16 pipeline changed after saving and loading.") def test_save_load_optional_components(self): @@ -421,7 +428,7 @@ def test_save_load_optional_components(self): inputs = self.get_dummy_inputs(torch_device) output_loaded = pipe_loaded(**inputs)[0] - max_diff = np.abs(output - output_loaded).max() + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() self.assertLess(max_diff, 1e-4) @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") @@ -442,7 +449,7 @@ def test_to_device(self): self.assertTrue(all(device == "cuda" for device in model_devices)) output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0] - self.assertTrue(np.isnan(output_cuda).sum() == 0) + self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) def test_to_dtype(self): components = self.get_dummy_components() @@ -482,7 +489,7 @@ def _test_attention_slicing_forward_pass( output_with_slicing = pipe(**inputs)[0] if test_max_difference: - max_diff = np.abs(output_with_slicing - output_without_slicing).max() + max_diff = np.abs(to_np(output_with_slicing) - to_np(output_without_slicing)).max() self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results") if test_mean_pixel_difference: @@ -508,7 +515,7 @@ def test_cpu_offload_forward_pass(self): inputs = self.get_dummy_inputs(torch_device) output_with_offload = pipe(**inputs)[0] - max_diff = np.abs(output_with_offload - output_without_offload).max() + max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max() self.assertLess(max_diff, 1e-4, "CPU offloading should not affect the inference results") @unittest.skipIf(