diff --git a/PHILOSOPHY.md b/PHILOSOPHY.md index 38c735480664..8baf11103d84 100644 --- a/PHILOSOPHY.md +++ b/PHILOSOPHY.md @@ -82,7 +82,7 @@ Models are designed as configurable toolboxes that are natural extensions of [Py The following design principles are followed: - Models correspond to **a type of model architecture**. *E.g.* the [`UNet2DConditionModel`] class is used for all UNet variations that expect 2D image inputs and are conditioned on some context. - All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py), [`transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py), etc... -- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modelling files and shows that models do not really follow the single-file policy. +- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modeling files and shows that models do not really follow the single-file policy. - Models intend to expose complexity, just like PyTorch's `Module` class, and give clear error messages. - Models all inherit from `ModelMixin` and `ConfigMixin`. - Models can be optimized for performance when it doesn’t demand major code changes, keep backward compatibility, and give significant memory or compute gain. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index e855ea36e8cf..a2b29e0dd570 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -94,6 +94,8 @@ title: Latent Consistency Model-LoRA - local: using-diffusers/inference_with_lcm title: Latent Consistency Model + - local: using-diffusers/svd + title: Stable Video Diffusion title: Specific pipeline examples - sections: - local: training/overview diff --git a/docs/source/en/using-diffusers/svd.md b/docs/source/en/using-diffusers/svd.md new file mode 100644 index 000000000000..5e009bfa088f --- /dev/null +++ b/docs/source/en/using-diffusers/svd.md @@ -0,0 +1,131 @@ + + +# Stable Video Diffusion + +[[open-in-colab]] + +[Stable Video Diffusion](https://static1.squarespace.com/static/6213c340453c3f502425776e/t/655ce779b9d47d342a93c890/1700587395994/stable_video_diffusion.pdf) is a powerful image-to-video generation model that can generate high resolution (576x1024) 2-4 second videos conditioned on the input image. + +This guide will show you how to use SVD to short generate videos from images. + +Before you begin, make sure you have the following libraries installed: + +```py +!pip install -q -U diffusers transformers accelerate +``` + +## Image to Video Generation + +The are two variants of SVD. [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid) +and [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt). The svd checkpoint is trained to generate 14 frames and the svd-xt checkpoint is further +finetuned to generate 25 frames. + +We will use the `svd-xt` checkpoint for this guide. + +```python +import torch + +from diffusers import StableVideoDiffusionPipeline +from diffusers.utils import load_image, export_to_video + +pipe = StableVideoDiffusionPipeline.from_pretrained( + "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16" +) +pipe.enable_model_cpu_offload() + +# Load the conditioning image +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true") +image = image.resize((1024, 576)) + +generator = torch.manual_seed(42) +frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0] + +export_to_video(frames, "generated.mp4", fps=7) +``` + + + + +Since generating videos is more memory intensive we can use the `decode_chunk_size` argument to control how many frames are decoded at once. This will reduce the memory usage. It's recommended to tweak this value based on your GPU memory. +Setting `decode_chunk_size=1` will decode one frame at a time and will use the least amount of memory but the video might have some flickering. + +Additionally, we also use [model cpu offloading](../../optimization/memory#model-offloading) to reduce the memory usage. + + + +### Torch.compile + +You can achieve a 20-25% speed-up at the expense of slightly increased memory by compiling the UNet as follows: + +```diff +- pipe.enable_model_cpu_offload() ++ pipe.to("cuda") ++ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) +``` + +### Low-memory + +Video generation is very memory intensive as we have to essentially generate `num_frames` all at once. The mechanism is very comparable to text-to-image generation with a high batch size. To reduce the memory requirement you have multiple options. The following options trade inference speed against lower memory requirement: +- enable model offloading: Each component of the pipeline is offloaded to CPU once it's not needed anymore. +- enable feed-forward chunking: The feed-forward layer runs in a loop instead of running with a single huge feed-forward batch size +- reduce `decode_chunk_size`: This means that the VAE decodes frames in chunks instead of decoding them all together. **Note**: In addition to leading to a small slowdown, this method also slightly leads to video quality deterioration + + You can enable them as follows: + ```diff +-pipe.enable_model_cpu_offload() +-frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0] ++pipe.enable_model_cpu_offload() ++pipe.unet.enable_forward_chunking() ++frames = pipe(image, decode_chunk_size=2, generator=generator, num_frames=25).frames[0] +``` + + +Including all these tricks should lower the memory requirement to less than 8GB VRAM. + +### Micro-conditioning + +Along with conditioning image Stable Diffusion Video also allows providing micro-conditioning that allows more control over the generated video. +It accepts the following arguments: + +- `fps`: The frames per second of the generated video. +- `motion_bucket_id`: The motion bucket id to use for the generated video. This can be used to control the motion of the generated video. Increasing the motion bucket id will increase the motion of the generated video. +- `noise_aug_strength`: The amount of noise added to the conditioning image. The higher the values the less the video will resemble the conditioning image. Increasing this value will also increase the motion of the generated video. + +Here is an example of using micro-conditioning to generate a video with more motion. + +```python +import torch + +from diffusers import StableVideoDiffusionPipeline +from diffusers.utils import load_image, export_to_video + +pipe = StableVideoDiffusionPipeline.from_pretrained( + "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16" +) +pipe.enable_model_cpu_offload() + +# Load the conditioning image +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true") +image = image.resize((1024, 576)) + +generator = torch.manual_seed(42) +frames = pipe(image, decode_chunk_size=8, generator=generator, motion_bucket_id=180, noise_aug_strength=0.1).frames[0] +export_to_video(frames, "generated.mp4", fps=7) +``` + + + diff --git a/scripts/convert_svd_to_diffusers.py b/scripts/convert_svd_to_diffusers.py new file mode 100644 index 000000000000..3243ce294b26 --- /dev/null +++ b/scripts/convert_svd_to_diffusers.py @@ -0,0 +1,730 @@ +from diffusers.utils import is_accelerate_available, logging + + +if is_accelerate_available(): + pass + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + if controlnet: + unet_params = original_config.model.params.control_stage_config.params + else: + if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: + unet_params = original_config.model.params.unet_config.params + else: + unet_params = original_config.model.params.network_config.params + + vae_params = original_config.model.params.first_stage_config.params.encoder_config.params + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = ( + "CrossAttnDownBlockSpatioTemporal" + if resolution in unet_params.attention_resolutions + else "DownBlockSpatioTemporal" + ) + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = ( + "CrossAttnUpBlockSpatioTemporal" + if resolution in unet_params.attention_resolutions + else "UpBlockSpatioTemporal" + ) + up_block_types.append(block_type) + resolution //= 2 + + if unet_params.transformer_depth is not None: + transformer_layers_per_block = ( + unet_params.transformer_depth + if isinstance(unet_params.transformer_depth, int) + else list(unet_params.transformer_depth) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + head_dim = unet_params.num_heads if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params.model_channels // unet_params.num_head_channels + head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params.context_dim is not None: + context_dim = ( + unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] + ) + + if "num_classes" in unet_params: + if unet_params.num_classes == "sequential": + addition_time_embed_dim = 256 + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params.adm_in_channels + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params.in_channels, + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_res_blocks, + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if "disable_self_attentions" in unet_params: + config["only_cross_attention"] = unet_params.disable_self_attentions + + if "num_classes" in unet_params and isinstance(unet_params.num_classes, int): + config["num_class_embeds"] = unet_params.num_classes + + if controlnet: + config["conditioning_channels"] = unet_params.hint_channels + else: + config["out_channels"] = unet_params.out_channels + config["up_block_types"] = tuple(up_block_types) + + return config + + +def assign_to_checkpoint( + paths, + checkpoint, + old_checkpoint, + attention_paths_to_split=None, + additional_replacements=None, + config=None, + mid_block_suffix="", +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + if mid_block_suffix is not None: + mid_block_suffix = f".{mid_block_suffix}" + else: + mid_block_suffix = "" + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", f"mid_block.resnets.0{mid_block_suffix}") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", f"mid_block.resnets.1{mid_block_suffix}") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + if new_path == "mid_block.resnets.0.spatial_res_block.norm1.weight": + print("yeyy") + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = new_item.replace("time_stack", "temporal_transformer_blocks") + + new_item = new_item.replace("time_pos_embed.0.bias", "time_pos_embed.linear_1.bias") + new_item = new_item.replace("time_pos_embed.0.weight", "time_pos_embed.linear_1.weight") + new_item = new_item.replace("time_pos_embed.2.bias", "time_pos_embed.linear_2.bias") + new_item = new_item.replace("time_pos_embed.2.weight", "time_pos_embed.linear_2.weight") + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = new_item.replace("time_stack.", "") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def convert_ldm_unet_checkpoint( + checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False +): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + if skip_extract_state_dict: + unet_state_dict = checkpoint + else: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + # if config["addition_embed_type"] == "text_time": + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + spatial_resnets = [ + key + for key in input_blocks[i] + if f"input_blocks.{i}.0" in key + and ( + f"input_blocks.{i}.0.op" not in key + and f"input_blocks.{i}.0.time_stack" not in key + and f"input_blocks.{i}.0.time_mixer" not in key + ) + ] + temporal_resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0.time_stack" in key] + # import ipdb; ipdb.set_trace() + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(spatial_resnets) + meta_path = { + "old": f"input_blocks.{i}.0", + "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + paths = renew_resnet_paths(temporal_resnets) + meta_path = { + "old": f"input_blocks.{i}.0", + "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + # TODO resnet time_mixer.mix_factor + if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict: + new_checkpoint[ + f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor" + ] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + # import ipdb; ipdb.set_trace() + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_spatial = [key for key in resnet_0 if "time_stack" not in key and "time_mixer" not in key] + resnet_0_paths = renew_resnet_paths(resnet_0_spatial) + # import ipdb; ipdb.set_trace() + assign_to_checkpoint( + resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block" + ) + + resnet_0_temporal = [key for key in resnet_0 if "time_stack" in key and "time_mixer" not in key] + resnet_0_paths = renew_resnet_paths(resnet_0_temporal) + assign_to_checkpoint( + resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block" + ) + + resnet_1_spatial = [key for key in resnet_1 if "time_stack" not in key and "time_mixer" not in key] + resnet_1_paths = renew_resnet_paths(resnet_1_spatial) + assign_to_checkpoint( + resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block" + ) + + resnet_1_temporal = [key for key in resnet_1 if "time_stack" in key and "time_mixer" not in key] + resnet_1_paths = renew_resnet_paths(resnet_1_temporal) + assign_to_checkpoint( + resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block" + ) + + new_checkpoint["mid_block.resnets.0.time_mixer.mix_factor"] = unet_state_dict[ + "middle_block.0.time_mixer.mix_factor" + ] + new_checkpoint["mid_block.resnets.1.time_mixer.mix_factor"] = unet_state_dict[ + "middle_block.2.time_mixer.mix_factor" + ] + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + spatial_resnets = [ + key + for key in output_blocks[i] + if f"output_blocks.{i}.0" in key + and (f"output_blocks.{i}.0.time_stack" not in key and "time_mixer" not in key) + ] + # import ipdb; ipdb.set_trace() + + temporal_resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0.time_stack" in key] + + paths = renew_resnet_paths(spatial_resnets) + meta_path = { + "old": f"output_blocks.{i}.0", + "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + paths = renew_resnet_paths(temporal_resnets) + meta_path = { + "old": f"output_blocks.{i}.0", + "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict: + new_checkpoint[ + f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor" + ] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"] + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and "conv" not in key] + if len(attentions): + paths = renew_attention_paths(attentions) + # import ipdb; ipdb.set_trace() + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + spatial_layers = [ + layer for layer in output_block_layers if "time_stack" not in layer and "time_mixer" not in layer + ] + resnet_0_paths = renew_resnet_paths(spatial_layers, n_shave_prefix_segments=1) + # import ipdb; ipdb.set_trace() + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join( + ["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "spatial_res_block", path["new"]] + ) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + temporal_layers = [ + layer for layer in output_block_layers if "time_stack" in layer and "time_mixer" not in key + ] + resnet_0_paths = renew_resnet_paths(temporal_layers, n_shave_prefix_segments=1) + # import ipdb; ipdb.set_trace() + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join( + ["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "temporal_res_block", path["new"]] + ) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + new_checkpoint["up_blocks.0.resnets.0.time_mixer.mix_factor"] = unet_state_dict[ + f"output_blocks.{str(i)}.0.time_mixer.mix_factor" + ] + + return new_checkpoint + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0, is_temporal=False): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # Temporal resnet + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = new_item.replace("time_stack.", "temporal_res_block.") + + # Spatial resnet + new_item = new_item.replace("conv1", "spatial_res_block.conv1") + new_item = new_item.replace("norm1", "spatial_res_block.norm1") + + new_item = new_item.replace("conv2", "spatial_res_block.conv2") + new_item = new_item.replace("norm2", "spatial_res_block.norm2") + + new_item = new_item.replace("nin_shortcut", "spatial_res_block.conv_shortcut") + + new_item = new_item.replace("mix_factor", "spatial_res_block.time_mixer.mix_factor") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else "" + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.time_mix_conv.weight"] + new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.time_mix_conv.bias"] + + # new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + # new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + # new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + # new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8a0dc2b923d3..751c098097c7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -76,6 +76,7 @@ [ "AsymmetricAutoencoderKL", "AutoencoderKL", + "AutoencoderKLTemporalDecoder", "AutoencoderTiny", "ConsistencyDecoderVAE", "ControlNetModel", @@ -92,6 +93,7 @@ "UNet2DModel", "UNet3DConditionModel", "UNetMotionModel", + "UNetSpatioTemporalConditionModel", "VQModel", ] ) @@ -277,6 +279,7 @@ "StableDiffusionXLPipeline", "StableUnCLIPImg2ImgPipeline", "StableUnCLIPPipeline", + "StableVideoDiffusionPipeline", "TextToVideoSDPipeline", "TextToVideoZeroPipeline", "UnCLIPImageVariationPipeline", @@ -446,6 +449,7 @@ from .models import ( AsymmetricAutoencoderKL, AutoencoderKL, + AutoencoderKLTemporalDecoder, AutoencoderTiny, ConsistencyDecoderVAE, ControlNetModel, @@ -462,6 +466,7 @@ UNet2DModel, UNet3DConditionModel, UNetMotionModel, + UNetSpatioTemporalConditionModel, VQModel, ) from .optimization import ( @@ -626,6 +631,7 @@ StableDiffusionXLPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, + StableVideoDiffusionPipeline, TextToVideoSDPipeline, TextToVideoZeroPipeline, UnCLIPImageVariationPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index de2e2848b848..839045001bb0 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -14,7 +14,12 @@ from typing import TYPE_CHECKING -from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available +from ..utils import ( + DIFFUSERS_SLOW_IMPORT, + _LazyModule, + is_flax_available, + is_torch_available, +) _import_structure = {} @@ -23,6 +28,7 @@ _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoder_kl"] = ["AutoencoderKL"] + _import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["controlnet"] = ["ControlNetModel"] @@ -38,6 +44,7 @@ _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"] _import_structure["unet_kandi3"] = ["Kandinsky3UNet"] _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] + _import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"] _import_structure["vq_model"] = ["VQModel"] if is_flax_available(): @@ -51,6 +58,7 @@ from .adapter import MultiAdapter, T2IAdapter from .autoencoder_asym_kl import AsymmetricAutoencoderKL from .autoencoder_kl import AutoencoderKL + from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_tiny import AutoencoderTiny from .consistency_decoder_vae import ConsistencyDecoderVAE from .controlnet import ControlNetModel @@ -66,6 +74,7 @@ from .unet_3d_condition import UNet3DConditionModel from .unet_kandi3 import Kandinsky3UNet from .unet_motion_model import MotionAdapter, UNetMotionModel + from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel from .vq_model import VQModel if is_flax_available(): diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 0c4c5de6e31a..f02b5e249eee 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -25,6 +25,31 @@ from .normalization import AdaLayerNorm, AdaLayerNormZero +def _chunked_feed_forward( + ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None +): + # "feed_forward_chunk_size" can be used to save memory + if hidden_states.shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = hidden_states.shape[chunk_dim] // chunk_size + if lora_scale is None: + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + else: + # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete + ff_output = torch.cat( + [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + + return ff_output + + @maybe_allow_in_graph class GatedSelfAttentionDense(nn.Module): r""" @@ -194,7 +219,12 @@ def __init__( if not self.use_ada_layer_norm_single: self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) # 4. Fuser if attention_type == "gated" or attention_type == "gated-text-image": @@ -208,7 +238,7 @@ def __init__( self._chunk_size = None self._chunk_dim = 0 - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): # Sets chunk feed-forward self._chunk_size = chunk_size self._chunk_dim = dim @@ -311,18 +341,8 @@ def forward( if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory - if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: - raise ValueError( - f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." - ) - - num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size - ff_output = torch.cat( - [ - self.ff(hid_slice, scale=lora_scale) - for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim) - ], - dim=self._chunk_dim, + ff_output = _chunked_feed_forward( + self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale ) else: ff_output = self.ff(norm_hidden_states, scale=lora_scale) @@ -339,6 +359,137 @@ def forward( return hidden_states +@maybe_allow_in_graph +class TemporalBasicTransformerBlock(nn.Module): + r""" + A basic Transformer block for video like data. + + Parameters: + dim (`int`): The number of channels in the input and output. + time_mix_inner_dim (`int`): The number of channels for temporal attention. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + """ + + def __init__( + self, + dim: int, + time_mix_inner_dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: Optional[int] = None, + ): + super().__init__() + self.is_res = dim == time_mix_inner_dim + + self.norm_in = nn.LayerNorm(dim) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm_in = nn.LayerNorm(dim) + self.ff_in = FeedForward( + dim, + dim_out=time_mix_inner_dim, + activation_fn="geglu", + ) + + self.norm1 = nn.LayerNorm(time_mix_inner_dim) + self.attn1 = Attention( + query_dim=time_mix_inner_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + cross_attention_dim=None, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = nn.LayerNorm(time_mix_inner_dim) + self.attn2 = Attention( + query_dim=time_mix_inner_dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(time_mix_inner_dim) + self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = None + + def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): + # Sets chunk feed-forward + self._chunk_size = chunk_size + # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off + self._chunk_dim = 1 + + def forward( + self, + hidden_states: torch.FloatTensor, + num_frames: int, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + batch_frames, seq_length, channels = hidden_states.shape + batch_size = batch_frames // num_frames + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) + + residual = hidden_states + hidden_states = self.norm_in(hidden_states) + + if self._chunk_size is not None: + hidden_states = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size) + else: + hidden_states = self.ff_in(hidden_states) + + if self.is_res: + hidden_states = hidden_states + residual + + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) + hidden_states = attn_output + hidden_states + + # 3. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = self.norm2(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self._chunk_size is not None: + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.is_res: + hidden_states = ff_output + hidden_states + else: + hidden_states = ff_output + + hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) + + return hidden_states + + class FeedForward(nn.Module): r""" A feed-forward layer. diff --git a/src/diffusers/models/autoencoder_asym_kl.py b/src/diffusers/models/autoencoder_asym_kl.py index 818e181fcdf0..678e47234096 100644 --- a/src/diffusers/models/autoencoder_asym_kl.py +++ b/src/diffusers/models/autoencoder_asym_kl.py @@ -18,7 +18,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils.accelerate_utils import apply_forward_hook -from .autoencoder_kl import AutoencoderKLOutput +from .modeling_outputs import AutoencoderKLOutput from .modeling_utils import ModelMixin from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 9003d982b32f..464bff9189dd 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import Dict, Optional, Tuple, Union import torch @@ -19,7 +18,6 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import FromOriginalVAEMixin -from ..utils import BaseOutput from ..utils.accelerate_utils import apply_forward_hook from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -28,24 +26,11 @@ AttnAddedKVProcessor, AttnProcessor, ) +from .modeling_outputs import AutoencoderKLOutput from .modeling_utils import ModelMixin from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder -@dataclass -class AutoencoderKLOutput(BaseOutput): - """ - Output of AutoencoderKL encoding method. - - Args: - latent_dist (`DiagonalGaussianDistribution`): - Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. - `DiagonalGaussianDistribution` allows for sampling latents from the distribution. - """ - - latent_dist: "DiagonalGaussianDistribution" - - class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. diff --git a/src/diffusers/models/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoder_kl_temporal_decoder.py new file mode 100644 index 000000000000..176b6e0df924 --- /dev/null +++ b/src/diffusers/models/autoencoder_kl_temporal_decoder.py @@ -0,0 +1,402 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import FromOriginalVAEMixin +from ..utils import is_torch_version +from ..utils.accelerate_utils import apply_forward_hook +from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from .modeling_outputs import AutoencoderKLOutput +from .modeling_utils import ModelMixin +from .unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder +from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder + + +class TemporalDecoder(nn.Module): + def __init__( + self, + in_channels: int = 4, + out_channels: int = 3, + block_out_channels: Tuple[int] = (128, 256, 512, 512), + layers_per_block: int = 2, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1) + self.mid_block = MidBlockTemporalDecoder( + num_layers=self.layers_per_block, + in_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + attention_head_dim=block_out_channels[-1], + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + up_block = UpBlockTemporalDecoder( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-6) + + self.conv_act = nn.SiLU() + self.conv_out = torch.nn.Conv2d( + in_channels=block_out_channels[0], + out_channels=out_channels, + kernel_size=3, + padding=1, + ) + + conv_out_kernel_size = (3, 1, 1) + padding = [int(k // 2) for k in conv_out_kernel_size] + self.time_conv_out = torch.nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=conv_out_kernel_size, + padding=padding, + ) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.FloatTensor, + image_only_indicator: torch.FloatTensor, + num_frames: int = 1, + ) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + sample, + image_only_indicator, + use_reentrant=False, + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + sample, + image_only_indicator, + use_reentrant=False, + ) + else: + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + sample, + image_only_indicator, + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + sample, + image_only_indicator, + ) + else: + # middle + sample = self.mid_block(sample, image_only_indicator=image_only_indicator) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, image_only_indicator=image_only_indicator) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + batch_frames, channels, height, width = sample.shape + batch_size = batch_frames // num_frames + sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + sample = self.time_conv_out(sample) + + sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width) + + return sample + + +class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + layers_per_block: (`int`, *optional*, defaults to 1): Number of layers per block. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + latent_channels: int = 4, + sample_size: int = 32, + scaling_factor: float = 0.18215, + force_upcast: float = True, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = TemporalDecoder( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Encoder, TemporalDecoder)): + module.gradient_checkpointing = value + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor, _remove_lora=True) + + @apply_forward_hook + def encode( + self, x: torch.FloatTensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + @apply_forward_hook + def decode( + self, + z: torch.FloatTensor, + num_frames: int, + return_dict: bool = True, + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + batch_size = z.shape[0] // num_frames + image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device) + decoded = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + num_frames: int = 1, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + + dec = self.decode(z, num_frames=num_frames).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/src/diffusers/models/modeling_outputs.py b/src/diffusers/models/modeling_outputs.py new file mode 100644 index 000000000000..8dfee5fec181 --- /dev/null +++ b/src/diffusers/models/modeling_outputs.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass + +from ..utils import BaseOutput + + +@dataclass +class AutoencoderKLOutput(BaseOutput): + """ + Output of AutoencoderKL encoding method. + + Args: + latent_dist (`DiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. + `DiagonalGaussianDistribution` allows for sampling latents from the distribution. + """ + + latent_dist: "DiagonalGaussianDistribution" # noqa: F821 diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 7a48d343a531..970d2be05b7a 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -165,7 +165,10 @@ def __init__( self.Conv2d_0 = conv def forward( - self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, scale: float = 1.0 + self, + hidden_states: torch.FloatTensor, + output_size: Optional[int] = None, + scale: float = 1.0, ) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels @@ -379,7 +382,11 @@ def _upsample_2d( weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) inverse_conv = F.conv_transpose2d( - hidden_states, weight, stride=stride, output_padding=output_padding, padding=0 + hidden_states, + weight, + stride=stride, + output_padding=output_padding, + padding=0, ) output = upfirdn2d_native( @@ -530,7 +537,14 @@ def __init__(self, pad_mode: str = "reflect"): def forward(self, inputs: torch.Tensor) -> torch.Tensor: inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode) - weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) + weight = inputs.new_zeros( + [ + inputs.shape[1], + inputs.shape[1], + self.kernel.shape[0], + self.kernel.shape[1], + ] + ) indices = torch.arange(inputs.shape[1], device=inputs.device) kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1) weight[indices, indices] = kernel @@ -553,7 +567,14 @@ def __init__(self, pad_mode: str = "reflect"): def forward(self, inputs: torch.Tensor) -> torch.Tensor: inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode) - weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) + weight = inputs.new_zeros( + [ + inputs.shape[1], + inputs.shape[1], + self.kernel.shape[0], + self.kernel.shape[1], + ] + ) indices = torch.arange(inputs.shape[1], device=inputs.device) kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1) weight[indices, indices] = kernel @@ -690,11 +711,19 @@ def __init__( self.conv_shortcut = None if self.use_in_shortcut: self.conv_shortcut = conv_cls( - in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias + in_channels, + conv_2d_out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=conv_shortcut_bias, ) def forward( - self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, scale: float = 1.0 + self, + input_tensor: torch.FloatTensor, + temb: torch.FloatTensor, + scale: float = 1.0, ) -> torch.FloatTensor: hidden_states = input_tensor @@ -866,7 +895,10 @@ def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: def upsample_2d( - hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1 + hidden_states: torch.FloatTensor, + kernel: Optional[torch.FloatTensor] = None, + factor: int = 2, + gain: float = 1, ) -> torch.FloatTensor: r"""Upsample2D a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given @@ -910,7 +942,10 @@ def upsample_2d( def downsample_2d( - hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1 + hidden_states: torch.FloatTensor, + kernel: Optional[torch.FloatTensor] = None, + factor: int = 2, + gain: float = 1, ) -> torch.FloatTensor: r"""Downsample2D a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the @@ -946,13 +981,20 @@ def downsample_2d( kernel = kernel * gain pad_value = kernel.shape[0] - factor output = upfirdn2d_native( - hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2) + hidden_states, + kernel.to(device=hidden_states.device), + down=factor, + pad=((pad_value + 1) // 2, pad_value // 2), ) return output def upfirdn2d_native( - tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0) + tensor: torch.Tensor, + kernel: torch.Tensor, + up: int = 1, + down: int = 1, + pad: Tuple[int, int] = (0, 0), ) -> torch.Tensor: up_x = up_y = up down_x = down_y = down @@ -1008,7 +1050,13 @@ class TemporalConvLayer(nn.Module): dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. """ - def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0, norm_num_groups: int = 32): + def __init__( + self, + in_dim: int, + out_dim: Optional[int] = None, + dropout: float = 0.0, + norm_num_groups: int = 32, + ): super().__init__() out_dim = out_dim or in_dim self.in_dim = in_dim @@ -1016,7 +1064,9 @@ def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = # conv layers self.conv1 = nn.Sequential( - nn.GroupNorm(norm_num_groups, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)) + nn.GroupNorm(norm_num_groups, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)), ) self.conv2 = nn.Sequential( nn.GroupNorm(norm_num_groups, out_dim), @@ -1058,3 +1108,261 @@ def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Ten (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:] ) return hidden_states + + +class TemporalResnetBlock(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + temb_channels: int = 512, + eps: float = 1e-6, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + kernel_size = (3, 1, 1) + padding = [k // 2 for k in kernel_size] + + self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True) + self.conv1 = nn.Conv3d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding, + ) + + if temb_channels is not None: + self.time_emb_proj = nn.Linear(temb_channels, out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = torch.nn.Dropout(0.0) + self.conv2 = nn.Conv3d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding, + ) + + self.nonlinearity = get_activation("silu") + + self.use_in_shortcut = self.in_channels != out_channels + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv3d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor: + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if self.time_emb_proj is not None: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb)[:, :, :, None, None] + temb = temb.permute(0, 2, 1, 3, 4) + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +# VideoResBlock +class SpatioTemporalResBlock(nn.Module): + r""" + A SpatioTemporal Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resenet. + temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet. + merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing. + merge_strategy (`str`, *optional*, defaults to `learned_with_images`): + The merge strategy to use for the temporal mixing. + switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`): + If `True`, switch the spatial and temporal mixing. + """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + temb_channels: int = 512, + eps: float = 1e-6, + temporal_eps: Optional[float] = None, + merge_factor: float = 0.5, + merge_strategy="learned_with_images", + switch_spatial_to_temporal_mix: bool = False, + ): + super().__init__() + + self.spatial_res_block = ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=eps, + ) + + self.temporal_res_block = TemporalResnetBlock( + in_channels=out_channels if out_channels is not None else in_channels, + out_channels=out_channels if out_channels is not None else in_channels, + temb_channels=temb_channels, + eps=temporal_eps if temporal_eps is not None else eps, + ) + + self.time_mixer = AlphaBlender( + alpha=merge_factor, + merge_strategy=merge_strategy, + switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix, + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ): + num_frames = image_only_indicator.shape[-1] + hidden_states = self.spatial_res_block(hidden_states, temb) + + batch_frames, channels, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + hidden_states_mix = ( + hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + ) + hidden_states = ( + hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + ) + + if temb is not None: + temb = temb.reshape(batch_size, num_frames, -1) + + hidden_states = self.temporal_res_block(hidden_states, temb) + hidden_states = self.time_mixer( + x_spatial=hidden_states_mix, + x_temporal=hidden_states, + image_only_indicator=image_only_indicator, + ) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width) + return hidden_states + + +class AlphaBlender(nn.Module): + r""" + A module to blend spatial and temporal features. + + Parameters: + alpha (`float`): The initial value of the blending factor. + merge_strategy (`str`, *optional*, defaults to `learned_with_images`): + The merge strategy to use for the temporal mixing. + switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`): + If `True`, switch the spatial and temporal mixing. + """ + + strategies = ["learned", "fixed", "learned_with_images"] + + def __init__( + self, + alpha: float, + merge_strategy: str = "learned_with_images", + switch_spatial_to_temporal_mix: bool = False, + ): + super().__init__() + self.merge_strategy = merge_strategy + self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix # For TemporalVAE + + if merge_strategy not in self.strategies: + raise ValueError(f"merge_strategy needs to be in {self.strategies}") + + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images": + self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) + else: + raise ValueError(f"Unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor: + if self.merge_strategy == "fixed": + alpha = self.mix_factor + + elif self.merge_strategy == "learned": + alpha = torch.sigmoid(self.mix_factor) + + elif self.merge_strategy == "learned_with_images": + if image_only_indicator is None: + raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy") + + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + torch.sigmoid(self.mix_factor)[..., None], + ) + + # (batch, channel, frames, height, width) + if ndims == 5: + alpha = alpha[:, None, :, None, None] + # (batch*frames, height*width, channels) + elif ndims == 3: + alpha = alpha.reshape(-1)[:, None, None] + else: + raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5") + + else: + raise NotImplementedError + + return alpha + + def forward( + self, + x_spatial: torch.Tensor, + x_temporal: torch.Tensor, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + alpha = self.get_alpha(image_only_indicator, x_spatial.ndim) + alpha = alpha.to(x_spatial.dtype) + + if self.switch_spatial_to_temporal_mix: + alpha = 1.0 - alpha + + x = alpha * x_spatial + (1.0 - alpha) * x_temporal + return x diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index 2e053d70eaa7..26e899a9b908 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -19,8 +19,10 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .attention import BasicTransformerBlock +from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock +from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin +from .resnet import AlphaBlender @dataclass @@ -195,3 +197,183 @@ def forward( return (output,) return TransformerTemporalModelOutput(sample=output) + + +class TransformerSpatioTemporalModel(nn.Module): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + out_channels (`int`, *optional*): + The number of channels in the output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: int = 320, + out_channels: Optional[int] = None, + num_layers: int = 1, + cross_attention_dim: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + + # 2. Define input layers + self.in_channels = in_channels + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for d in range(num_layers) + ] + ) + + time_mix_inner_dim = inner_dim + self.temporal_transformer_blocks = nn.ModuleList( + [ + TemporalBasicTransformerBlock( + inner_dim, + time_mix_inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for _ in range(num_layers) + ] + ) + + time_embed_dim = in_channels * 4 + self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels) + self.time_proj = Timesteps(in_channels, True, 0) + self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images") + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + # TODO: should use out_channels for continuous projections + self.proj_out = nn.Linear(inner_dim, in_channels) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input hidden_states. + num_frames (`int`): + The number of frames to be processed per batch. This is used to reshape the hidden states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*): + A tensor indicating whether the input contains only images. 1 indicates that the input contains only + images, 0 indicates that the input contains video frames. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain + tuple. + + Returns: + [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is + returned, otherwise a `tuple` where the first element is the sample tensor. + """ + # 1. Input + batch_frames, _, height, width = hidden_states.shape + num_frames = image_only_indicator.shape[-1] + batch_size = batch_frames // num_frames + + time_context = encoder_hidden_states + time_context_first_timestep = time_context[None, :].reshape( + batch_size, num_frames, -1, time_context.shape[-1] + )[:, 0] + time_context = time_context_first_timestep[None, :].broadcast_to( + height * width, batch_size, 1, time_context.shape[-1] + ) + time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1]) + + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + + num_frames_emb = torch.arange(num_frames, device=hidden_states.device) + num_frames_emb = num_frames_emb.repeat(batch_size, 1) + num_frames_emb = num_frames_emb.reshape(-1) + t_emb = self.time_proj(num_frames_emb) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + + emb = self.time_pos_embed(t_emb) + emb = emb[:, None, :] + + # 2. Blocks + for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks): + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + block, + hidden_states, + None, + encoder_hidden_states, + None, + use_reentrant=False, + ) + else: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states_mix = hidden_states + hidden_states_mix = hidden_states_mix + emb + + hidden_states_mix = temporal_block( + hidden_states_mix, + num_frames=num_frames, + encoder_hidden_states=time_context, + ) + hidden_states = self.time_mixer( + x_spatial=hidden_states, + x_temporal=hidden_states_mix, + image_only_indicator=image_only_indicator, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 767ab846d5dc..e9c505c347b0 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -19,10 +19,20 @@ from ..utils import is_torch_version from ..utils.torch_utils import apply_freeu +from .attention import Attention from .dual_transformer_2d import DualTransformer2DModel -from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D +from .resnet import ( + Downsample2D, + ResnetBlock2D, + SpatioTemporalResBlock, + TemporalConvLayer, + Upsample2D, +) from .transformer_2d import Transformer2DModel -from .transformer_temporal import TransformerTemporalModel +from .transformer_temporal import ( + TransformerSpatioTemporalModel, + TransformerTemporalModel, +) def get_down_block( @@ -45,7 +55,15 @@ def get_down_block( resnet_time_scale_shift: str = "default", temporal_num_attention_heads: int = 8, temporal_max_seq_length: int = 32, -) -> Union["DownBlock3D", "CrossAttnDownBlock3D", "DownBlockMotion", "CrossAttnDownBlockMotion"]: + transformer_layers_per_block: int = 1, +) -> Union[ + "DownBlock3D", + "CrossAttnDownBlock3D", + "DownBlockMotion", + "CrossAttnDownBlockMotion", + "DownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", +]: if down_block_type == "DownBlock3D": return DownBlock3D( num_layers=num_layers, @@ -118,6 +136,29 @@ def get_down_block( temporal_num_attention_heads=temporal_num_attention_heads, temporal_max_seq_length=temporal_max_seq_length, ) + elif down_block_type == "DownBlockSpatioTemporal": + # added for SDV + return DownBlockSpatioTemporal( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "CrossAttnDownBlockSpatioTemporal": + # added for SDV + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal") + return CrossAttnDownBlockSpatioTemporal( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + add_downsample=add_downsample, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + ) raise ValueError(f"{down_block_type} does not exist.") @@ -144,7 +185,16 @@ def get_up_block( temporal_num_attention_heads: int = 8, temporal_cross_attention_dim: Optional[int] = None, temporal_max_seq_length: int = 32, -) -> Union["UpBlock3D", "CrossAttnUpBlock3D", "UpBlockMotion", "CrossAttnUpBlockMotion"]: + transformer_layers_per_block: int = 1, + dropout: float = 0.0, +) -> Union[ + "UpBlock3D", + "CrossAttnUpBlock3D", + "UpBlockMotion", + "CrossAttnUpBlockMotion", + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", +]: if up_block_type == "UpBlock3D": return UpBlock3D( num_layers=num_layers, @@ -221,6 +271,34 @@ def get_up_block( temporal_num_attention_heads=temporal_num_attention_heads, temporal_max_seq_length=temporal_max_seq_length, ) + elif up_block_type == "UpBlockSpatioTemporal": + # added for SDV + return UpBlockSpatioTemporal( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + add_upsample=add_upsample, + ) + elif up_block_type == "CrossAttnUpBlockSpatioTemporal": + # added for SDV + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal") + return CrossAttnUpBlockSpatioTemporal( + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + add_upsample=add_upsample, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + resolution_idx=resolution_idx, + ) + raise ValueError(f"{up_block_type} does not exist.") @@ -347,7 +425,10 @@ def forward( return_dict=False, )[0] hidden_states = temp_attn( - hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, )[0] hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) @@ -443,7 +524,11 @@ def __init__( self.downsamplers = nn.ModuleList( [ Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", ) ] ) @@ -476,7 +561,10 @@ def forward( return_dict=False, )[0] hidden_states = temp_attn( - hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, )[0] output_states += (hidden_states,) @@ -543,7 +631,11 @@ def __init__( self.downsamplers = nn.ModuleList( [ Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", ) ] ) @@ -553,7 +645,10 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, num_frames: int = 1 + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + num_frames: int = 1, ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () @@ -716,7 +811,10 @@ def forward( return_dict=False, )[0] hidden_states = temp_attn( - hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, )[0] if self.upsamplers is not None: @@ -890,7 +988,11 @@ def __init__( self.downsamplers = nn.ModuleList( [ Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", ) ] ) @@ -920,14 +1022,20 @@ def custom_forward(*inputs): if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False, ) else: hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, scale ) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, num_frames + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + num_frames, ) else: @@ -1047,7 +1155,11 @@ def __init__( self.downsamplers = nn.ModuleList( [ Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", ) ] ) @@ -1442,7 +1554,10 @@ def custom_forward(*inputs): if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False, ) else: hidden_states = torch.utils.checkpoint.checkpoint( @@ -1636,3 +1751,645 @@ def custom_forward(*inputs): hidden_states = resnet(hidden_states, temb, scale=lora_scale) return hidden_states + + +class MidBlockTemporalDecoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + attention_head_dim: int = 512, + num_layers: int = 1, + upcast_attention: bool = False, + ): + super().__init__() + + resnets = [] + attentions = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=1e-6, + temporal_eps=1e-5, + merge_factor=0.0, + merge_strategy="learned", + switch_spatial_to_temporal_mix=True, + ) + ) + + attentions.append( + Attention( + query_dim=in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + eps=1e-6, + upcast_attention=upcast_attention, + norm_num_groups=32, + bias=True, + residual_connection=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + image_only_indicator: torch.FloatTensor, + ): + hidden_states = self.resnets[0]( + hidden_states, + image_only_indicator=image_only_indicator, + ) + for resnet, attn in zip(self.resnets[1:], self.attentions): + hidden_states = attn(hidden_states) + hidden_states = resnet( + hidden_states, + image_only_indicator=image_only_indicator, + ) + + return hidden_states + + +class UpBlockTemporalDecoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=1e-6, + temporal_eps=1e-5, + merge_factor=0.0, + merge_strategy="learned", + switch_spatial_to_temporal_mix=True, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward( + self, + hidden_states: torch.FloatTensor, + image_only_indicator: torch.FloatTensor, + ) -> torch.FloatTensor: + for resnet in self.resnets: + hidden_states = resnet( + hidden_states, + image_only_indicator=image_only_indicator, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UNetMidBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + # there is always at least one resnet + resnets = [ + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ] + attentions = [] + + for i in range(num_layers): + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + hidden_states = self.resnets[0]( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.training and self.gradient_checkpointing: # TODO + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + return hidden_states + + +class DownBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + num_layers: int = 1, + add_downsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + ) + else: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + add_downsample: bool = True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=1e-6, + ) + ) + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=1, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + for resnet, attn in blocks: + if self.training and self.gradient_checkpointing: # TODO + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + **ckpt_kwargs, + ) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + else: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class UpBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + num_layers: int = 1, + resnet_eps: float = 1e-6, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + ) + else: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CrossAttnUpBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + ) + ) + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: # TODO + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + else: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states diff --git a/src/diffusers/models/unet_spatio_temporal_condition.py b/src/diffusers/models/unet_spatio_temporal_condition.py new file mode 100644 index 000000000000..8d0d3e61d879 --- /dev/null +++ b/src/diffusers/models/unet_spatio_temporal_condition.py @@ -0,0 +1,489 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import UNet2DConditionLoadersMixin +from ..utils import BaseOutput, logging +from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from .embeddings import TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNetSpatioTemporalConditionOutput(BaseOutput): + """ + The output of [`UNetSpatioTemporalConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + + +class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + addition_time_embed_dim: (`int`, defaults to 256): + Dimension to to encode the additional time ids. + projection_class_embeddings_input_dim (`int`, defaults to 768): + The dimension of the projection of encoded `added_time_ids`. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], + [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. + num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): + The number of attention heads. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 8, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "DownBlockSpatioTemporal", + ), + up_block_types: Tuple[str] = ( + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + ), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + addition_time_embed_dim: int = 256, + projection_class_embeddings_input_dim: int = 768, + layers_per_block: Union[int, Tuple[int]] = 2, + cross_attention_dim: Union[int, Tuple[int]] = 1024, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), + num_frames: int = 25, + ): + super().__init__() + + self.sample_size = sample_size + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + padding=1, + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=1e-5, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + resnet_act_fn="silu", + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlockSpatioTemporal( + block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block[-1], + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=1e-5, + resolution_idx=i, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + resnet_act_fn="silu", + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5) + self.conv_act = nn.SiLU() + + self.conv_out = nn.Conv2d( + block_out_channels[0], + out_channels, + kernel_size=3, + padding=1, + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + added_time_ids: torch.Tensor, + return_dict: bool = True, + ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]: + r""" + The [`UNetSpatioTemporalConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. + added_time_ids: (`torch.FloatTensor`): + The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal + embeddings and added to the time embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain + tuple. + Returns: + [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + batch_size, num_frames = sample.shape[:2] + timesteps = timesteps.expand(batch_size) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb) + + time_embeds = self.add_time_proj(added_time_ids.flatten()) + time_embeds = time_embeds.reshape((batch_size, -1)) + time_embeds = time_embeds.to(emb.dtype) + aug_emb = self.add_embedding(time_embeds) + emb = emb + aug_emb + + # Flatten the batch and frames dimensions + # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] + sample = sample.flatten(0, 1) + # Repeat the embeddings num_video_frames times + # emb: [batch, channels] -> [batch * frames, channels] + emb = emb.repeat_interleave(num_frames, dim=0) + # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] + encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) + + # 2. pre-process + sample = self.conv_in(sample) + + image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + image_only_indicator=image_only_indicator, + ) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + image_only_indicator=image_only_indicator, + ) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # 7. Reshape back to original shape + sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) + + if not return_dict: + return (sample,) + + return UNetSpatioTemporalConditionOutput(sample=sample) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 0f849a66eaea..0049456e2187 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -22,7 +22,12 @@ from ..utils.torch_utils import randn_tensor from .activations import get_activation from .attention_processor import SpatialNorm -from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, get_up_block +from .unet_2d_blocks import ( + AutoencoderTinyBlock, + UNetMidBlock2D, + get_down_block, + get_up_block, +) @dataclass @@ -274,7 +279,9 @@ def __init__( self.gradient_checkpointing = False def forward( - self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None + self, + sample: torch.FloatTensor, + latent_embeds: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: r"""The forward method of the `Decoder` class.""" @@ -292,14 +299,20 @@ def custom_forward(*inputs): if is_torch_version(">=", "1.11.0"): # middle sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False + create_custom_forward(self.mid_block), + sample, + latent_embeds, + use_reentrant=False, ) sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False + create_custom_forward(up_block), + sample, + latent_embeds, + use_reentrant=False, ) else: # middle @@ -540,7 +553,10 @@ def custom_forward(*inputs): if is_torch_version(">=", "1.11.0"): # middle sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False + create_custom_forward(self.mid_block), + sample, + latent_embeds, + use_reentrant=False, ) sample = sample.to(upscale_dtype) @@ -548,7 +564,10 @@ def custom_forward(*inputs): if image is not None and mask is not None: masked_image = (1 - mask) * image im_x = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.condition_encoder), masked_image, mask, use_reentrant=False + create_custom_forward(self.condition_encoder), + masked_image, + mask, + use_reentrant=False, ) # up @@ -558,7 +577,10 @@ def custom_forward(*inputs): mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") sample = sample * mask_ + sample_ * (1 - mask_) sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False + create_custom_forward(up_block), + sample, + latent_embeds, + use_reentrant=False, ) if image is not None and mask is not None: sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) @@ -573,7 +595,9 @@ def custom_forward(*inputs): if image is not None and mask is not None: masked_image = (1 - mask) * image im_x = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.condition_encoder), masked_image, mask + create_custom_forward(self.condition_encoder), + masked_image, + mask, ) # up @@ -754,7 +778,10 @@ def __init__(self, parameters: torch.Tensor, deterministic: bool = False): def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: # make sure sample is on the same device as the parameters and has same dtype sample = randn_tensor( - self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, ) x = self.mean + self.std * sample return x @@ -764,7 +791,10 @@ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: return torch.Tensor([0.0]) else: if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) else: return 0.5 * torch.sum( torch.pow(self.mean - other.mean, 2) / other.var @@ -779,7 +809,10 @@ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch. if self.deterministic: return torch.Tensor([0.0]) logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) def mode(self) -> torch.Tensor: return self.mean @@ -820,7 +853,16 @@ def __init__( if i == 0: layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1)) else: - layers.append(nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, stride=2, bias=False)) + layers.append( + nn.Conv2d( + num_channels, + num_channels, + kernel_size=3, + padding=1, + stride=2, + bias=False, + ) + ) for _ in range(num_block): layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) @@ -899,7 +941,15 @@ def __init__( layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor)) conv_out_channel = num_channels if not is_final_block else out_channels - layers.append(nn.Conv2d(num_channels, conv_out_channel, kernel_size=3, padding=1, bias=is_final_block)) + layers.append( + nn.Conv2d( + num_channels, + conv_out_channel, + kernel_size=3, + padding=1, + bias=is_final_block, + ) + ) self.layers = nn.Sequential(*layers) self.gradient_checkpointing = False diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 78c1b7c6285d..df78fbeefe05 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -17,7 +17,12 @@ # These modules contain pipelines from multiple libraries/frameworks _dummy_objects = {} -_import_structure = {"stable_diffusion": [], "stable_diffusion_xl": [], "latent_diffusion": [], "controlnet": []} +_import_structure = { + "controlnet": [], + "latent_diffusion": [], + "stable_diffusion": [], + "stable_diffusion_xl": [], +} try: if not is_torch_available(): @@ -39,7 +44,11 @@ _import_structure["dit"] = ["DiTPipeline"] _import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"]) _import_structure["latent_diffusion_uncond"] = ["LDMPipeline"] - _import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"] + _import_structure["pipeline_utils"] = [ + "AudioPipelineOutput", + "DiffusionPipeline", + "ImagePipelineOutput", + ] _import_structure["pndm"] = ["PNDMPipeline"] _import_structure["repaint"] = ["RePaintPipeline"] _import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"] @@ -61,7 +70,10 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"] + _import_structure["alt_diffusion"] = [ + "AltDiffusionImg2ImgPipeline", + "AltDiffusionPipeline", + ] _import_structure["animatediff"] = ["AnimateDiffPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ @@ -110,7 +122,10 @@ "KandinskyV22PriorEmb2EmbPipeline", "KandinskyV22PriorPipeline", ] - _import_structure["kandinsky3"] = ["Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline"] + _import_structure["kandinsky3"] = [ + "Kandinsky3Img2ImgPipeline", + "Kandinsky3Pipeline", + ] _import_structure["latent_consistency_models"] = [ "LatentConsistencyModelImg2ImgPipeline", "LatentConsistencyModelPipeline", @@ -150,6 +165,7 @@ ] ) _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"] + _import_structure["stable_video_diffusion"] = ["StableVideoDiffusionPipeline"] _import_structure["stable_diffusion_xl"].extend( [ "StableDiffusionXLImg2ImgPipeline", @@ -158,7 +174,10 @@ "StableDiffusionXLPipeline", ] ) - _import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"] + _import_structure["t2i_adapter"] = [ + "StableDiffusionAdapterPipeline", + "StableDiffusionXLAdapterPipeline", + ] _import_structure["text_to_video_synthesis"] = [ "TextToVideoSDPipeline", "TextToVideoZeroPipeline", @@ -215,7 +234,9 @@ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 + from ..utils import ( + dummy_torch_and_transformers_and_k_diffusion_objects, + ) _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects)) else: @@ -258,7 +279,10 @@ _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) else: - _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"] + _import_structure["spectrogram_diffusion"] = [ + "MidiProcessor", + "SpectrogramDiffusionPipeline", + ] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -268,7 +292,11 @@ from ..utils.dummy_pt_objects import * # noqa F403 else: - from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image + from .auto_pipeline import ( + AutoPipelineForImage2Image, + AutoPipelineForInpainting, + AutoPipelineForText2Image, + ) from .consistency_models import ConsistencyModelPipeline from .dance_diffusion import DanceDiffusionPipeline from .ddim import DDIMPipeline @@ -276,7 +304,11 @@ from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline from .latent_diffusion_uncond import LDMPipeline - from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput + from .pipeline_utils import ( + AudioPipelineOutput, + DiffusionPipeline, + ImagePipelineOutput, + ) from .pndm import PNDMPipeline from .repaint import RePaintPipeline from .score_sde_ve import ScoreSdeVePipeline @@ -299,7 +331,11 @@ from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline from .animatediff import AnimateDiffPipeline from .audioldm import AudioLDMPipeline - from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel + from .audioldm2 import ( + AudioLDM2Pipeline, + AudioLDM2ProjectionModel, + AudioLDM2UNet2DConditionModel, + ) from .blip_diffusion import BlipDiffusionPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, @@ -343,7 +379,10 @@ Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline, ) - from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline + from .latent_consistency_models import ( + LatentConsistencyModelImg2ImgPipeline, + LatentConsistencyModelPipeline, + ) from .latent_diffusion import LDMTextToImagePipeline from .musicldm import MusicLDMPipeline from .paint_by_example import PaintByExamplePipeline @@ -382,7 +421,11 @@ StableDiffusionXLInstructPix2PixPipeline, StableDiffusionXLPipeline, ) - from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline + from .stable_video_diffusion import StableVideoDiffusionPipeline + from .t2i_adapter import ( + StableDiffusionAdapterPipeline, + StableDiffusionXLAdapterPipeline, + ) from .text_to_video_synthesis import ( TextToVideoSDPipeline, TextToVideoZeroPipeline, @@ -471,7 +514,10 @@ from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 else: - from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline + from .spectrogram_diffusion import ( + MidiProcessor, + SpectrogramDiffusionPipeline, + ) else: import sys diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index fcdca9c9f08b..5706298a281a 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -55,7 +55,9 @@ if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline + from ...utils.dummy_torch_and_transformers_objects import ( + StableDiffusionImageVariationPipeline, + ) _dummy_objects.update({"StableDiffusionImageVariationPipeline": StableDiffusionImageVariationPipeline}) else: @@ -90,7 +92,9 @@ ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from ...utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 + from ...utils import ( + dummy_torch_and_transformers_and_k_diffusion_objects, + ) _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects)) else: @@ -137,18 +141,32 @@ StableDiffusionPipelineOutput, StableDiffusionSafetyChecker, ) - from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline + from .pipeline_stable_diffusion_attend_and_excite import ( + StableDiffusionAttendAndExcitePipeline, + ) from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline - from .pipeline_stable_diffusion_gligen_text_image import StableDiffusionGLIGENTextImagePipeline + from .pipeline_stable_diffusion_gligen_text_image import ( + StableDiffusionGLIGENTextImagePipeline, + ) from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline - from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy - from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline - from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline + from .pipeline_stable_diffusion_inpaint_legacy import ( + StableDiffusionInpaintPipelineLegacy, + ) + from .pipeline_stable_diffusion_instruct_pix2pix import ( + StableDiffusionInstructPix2PixPipeline, + ) + from .pipeline_stable_diffusion_latent_upscale import ( + StableDiffusionLatentUpscalePipeline, + ) from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline - from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline + from .pipeline_stable_diffusion_model_editing import ( + StableDiffusionModelEditingPipeline, + ) from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline - from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline + from .pipeline_stable_diffusion_paradigms import ( + StableDiffusionParadigmsPipeline, + ) from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline from .pipeline_stable_unclip import StableUnCLIPPipeline @@ -160,9 +178,13 @@ if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline + from ...utils.dummy_torch_and_transformers_objects import ( + StableDiffusionImageVariationPipeline, + ) else: - from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline + from .pipeline_stable_diffusion_image_variation import ( + StableDiffusionImageVariationPipeline, + ) try: if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")): @@ -174,9 +196,13 @@ StableDiffusionPix2PixZeroPipeline, ) else: - from .pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline + from .pipeline_stable_diffusion_depth2img import ( + StableDiffusionDepth2ImgPipeline, + ) from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline - from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline + from .pipeline_stable_diffusion_pix2pix_zero import ( + StableDiffusionPix2PixZeroPipeline, + ) try: if not ( @@ -189,7 +215,9 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * else: - from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline + from .pipeline_stable_diffusion_k_diffusion import ( + StableDiffusionKDiffusionPipeline, + ) try: if not (is_transformers_available() and is_onnx_available()): @@ -197,11 +225,22 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_onnx_objects import * else: - from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline - from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline - from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline - from .pipeline_onnx_stable_diffusion_inpaint_legacy import OnnxStableDiffusionInpaintPipelineLegacy - from .pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline + from .pipeline_onnx_stable_diffusion import ( + OnnxStableDiffusionPipeline, + StableDiffusionOnnxPipeline, + ) + from .pipeline_onnx_stable_diffusion_img2img import ( + OnnxStableDiffusionImg2ImgPipeline, + ) + from .pipeline_onnx_stable_diffusion_inpaint import ( + OnnxStableDiffusionInpaintPipeline, + ) + from .pipeline_onnx_stable_diffusion_inpaint_legacy import ( + OnnxStableDiffusionInpaintPipelineLegacy, + ) + from .pipeline_onnx_stable_diffusion_upscale import ( + OnnxStableDiffusionUpscalePipeline, + ) try: if not (is_transformers_available() and is_flax_available()): @@ -210,8 +249,12 @@ from ...utils.dummy_flax_objects import * else: from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline - from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline - from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline + from .pipeline_flax_stable_diffusion_img2img import ( + FlaxStableDiffusionImg2ImgPipeline, + ) + from .pipeline_flax_stable_diffusion_inpaint import ( + FlaxStableDiffusionInpaintPipeline, + ) from .pipeline_output import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_video_diffusion/__init__.py b/src/diffusers/pipelines/stable_video_diffusion/__init__.py new file mode 100644 index 000000000000..3bd4dc78966e --- /dev/null +++ b/src/diffusers/pipelines/stable_video_diffusion/__init__.py @@ -0,0 +1,58 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + BaseOutput, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure.update( + { + "pipeline_stable_video_diffusion": [ + "StableVideoDiffusionPipeline", + "StableVideoDiffusionPipelineOutput", + ], + } + ) + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_video_diffusion import ( + StableVideoDiffusionPipeline, + StableVideoDiffusionPipelineOutput, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py new file mode 100644 index 000000000000..a82f5379e71a --- /dev/null +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -0,0 +1,649 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import BaseOutput, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +def tensor2vid(video: torch.Tensor, processor, output_type="np"): + # Based on: + # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + + outputs.append(batch_output) + + return outputs + + +@dataclass +class StableVideoDiffusionPipelineOutput(BaseOutput): + r""" + Output class for zero-shot text-to-video pipeline. + + Args: + frames (`[List[PIL.Image.Image]`, `np.ndarray`]): + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + """ + + frames: Union[List[PIL.Image.Image], np.ndarray] + + +class StableVideoDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline to generate video from an input image using Stable Video Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)). + unet ([`UNetSpatioTemporalConditionModel`]): + A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents. + scheduler ([`EulerDiscreteScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images. + """ + + model_cpu_offload_seq = "image_encoder->unet->vae" + _callback_tensor_inputs = ["latents"] + + def __init__( + self, + vae: AutoencoderKLTemporalDecoder, + image_encoder: CLIPVisionModelWithProjection, + unet: UNetSpatioTemporalConditionModel, + scheduler: EulerDiscreteScheduler, + feature_extractor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.image_processor.pil_to_numpy(image) + image = self.image_processor.numpy_to_pt(image) + + # We normalize the image before resizing to match with the original implementation. + # Then we unnormalize it after resizing. + image = image * 2.0 - 1.0 + image = _resize_with_antialiasing(image, (224, 224)) + image = (image + 1.0) / 2.0 + + # Normalize the image with for CLIP input + image = self.feature_extractor( + images=image, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_image_embeddings = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) + + return image_embeddings + + def _encode_vae_image( + self, + image: torch.Tensor, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + ): + image = image.to(device=device) + image_latents = self.vae.encode(image).latent_dist.mode() + + if do_classifier_free_guidance: + negative_image_latents = torch.zeros_like(image_latents) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_latents = torch.cat([negative_image_latents, image_latents]) + + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) + + return image_latents + + def _get_add_time_ids( + self, + fps, + motion_bucket_id, + noise_aug_strength, + dtype, + batch_size, + num_videos_per_prompt, + do_classifier_free_guidance, + ): + add_time_ids = [fps, motion_bucket_id, noise_aug_strength] + + passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) + + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids]) + + return add_time_ids + + def decode_latents(self, latents, num_frames, decode_chunk_size=14): + # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] + latents = latents.flatten(0, 1) + + latents = 1 / self.vae.config.scaling_factor * latents + + accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys()) + + # decode decode_chunk_size frames at a time to avoid OOM + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + num_frames_in = latents[i : i + decode_chunk_size].shape[0] + decode_kwargs = {} + if accepts_num_frames: + # we only pass num_frames_in if it's expected + decode_kwargs["num_frames"] = num_frames_in + + frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample + frames.append(frame) + frames = torch.cat(frames, dim=0) + + # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] + frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + frames = frames.float() + return frames + + def check_inputs(self, image, height, width): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + def prepare_latents( + self, + batch_size, + num_frames, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_frames, + num_channels_latents // 2, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], + height: int = 576, + width: int = 1024, + num_frames: Optional[int] = None, + num_inference_steps: int = 25, + min_guidance_scale: float = 1.0, + max_guidance_scale: float = 3.0, + fps: int = 7, + motion_bucket_id: int = 127, + noise_aug_strength: int = 0.02, + decode_chunk_size: Optional[int] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + return_dict: bool = True, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + Image or images to guide image generation. If you provide a tensor, it needs to be compatible with + [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_frames (`int`, *optional*): + The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt` + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + min_guidance_scale (`float`, *optional*, defaults to 1.0): + The minimum guidance scale. Used for the classifier free guidance with first frame. + max_guidance_scale (`float`, *optional*, defaults to 3.0): + The maximum guidance scale. Used for the classifier free guidance with last frame. + fps (`int`, *optional*, defaults to 7): + Frames per second. The rate at which the generated images shall be exported to a video after generation. + Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training. + motion_bucket_id (`int`, *optional*, defaults to 127): + The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video. + noise_aug_strength (`int`, *optional*, defaults to 0.02): + The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion. + decode_chunk_size (`int`, *optional*): + The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency + between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once + for maximal quality. Reduce `decode_chunk_size` to reduce memory usage. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Returns: + [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list of list with the generated frames. + + Examples: + + ```py + from diffusers import StableVideoDiffusionPipeline + from diffusers.utils import load_image, export_to_video + + pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16") + pipe.to("cuda") + + image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200") + image = image.resize((1024, 576)) + + frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0] + export_to_video(frames, "generated.mp4", fps=7) + ``` + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_frames = num_frames if num_frames is not None else self.unet.config.num_frames + decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = max_guidance_scale > 1.0 + + # 3. Encode input image + image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) + + # NOTE: Stable Diffusion Video was conditioned on fps - 1, which + # is why it is reduced here. + # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188 + fps = fps - 1 + + # 4. Encode input image using VAE + image = self.image_processor.preprocess(image, height=height, width=width) + noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype) + image = image + noise_aug_strength * noise + + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.vae.to(dtype=torch.float32) + + image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) + image_latents = image_latents.to(image_embeddings.dtype) + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + # Repeat the image latents for each frame so we can concatenate them with the noise + # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] + image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1) + + # 5. Get Added Time IDs + added_time_ids = self._get_add_time_ids( + fps, + motion_bucket_id, + noise_aug_strength, + image_embeddings.dtype, + batch_size, + num_videos_per_prompt, + do_classifier_free_guidance, + ) + added_time_ids = added_time_ids.to(device) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_frames, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Prepare guidance scale + guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0) + guidance_scale = guidance_scale.to(device, latents.dtype) + guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1) + guidance_scale = _append_dims(guidance_scale, latents.ndim) + + self._guidance_scale = guidance_scale + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Concatenate image_latents over channels dimention + latent_model_input = torch.cat([latent_model_input, image_latents], dim=2) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=image_embeddings, + added_time_ids=added_time_ids, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + frames = self.decode_latents(latents, num_frames, decode_chunk_size) + frames = tensor2vid(frames, self.image_processor, output_type=output_type) + else: + frames = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return frames + + return StableVideoDiffusionPipelineOutput(frames=frames) + + +# resizing utils +# TODO: clean up later +def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): + h, w = input.shape[-2:] + factors = (h / size[0], w / size[1]) + + # First, we have to determine sigma + # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 + sigmas = ( + max((factors[0] - 1.0) / 2.0, 0.001), + max((factors[1] - 1.0) / 2.0, 0.001), + ) + + # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma + # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 + # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now + ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) + + # Make sure it is odd + if (ks[0] % 2) == 0: + ks = ks[0] + 1, ks[1] + + if (ks[1] % 2) == 0: + ks = ks[0], ks[1] + 1 + + input = _gaussian_blur2d(input, ks, sigmas) + + output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) + return output + + +def _compute_padding(kernel_size): + """Compute padding tuple.""" + # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) + # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad + if len(kernel_size) < 2: + raise AssertionError(kernel_size) + computed = [k - 1 for k in kernel_size] + + # for even kernels we need to do asymmetric padding :( + out_padding = 2 * len(kernel_size) * [0] + + for i in range(len(kernel_size)): + computed_tmp = computed[-(i + 1)] + + pad_front = computed_tmp // 2 + pad_rear = computed_tmp - pad_front + + out_padding[2 * i + 0] = pad_front + out_padding[2 * i + 1] = pad_rear + + return out_padding + + +def _filter2d(input, kernel): + # prepare kernel + b, c, h, w = input.shape + tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) + + height, width = tmp_kernel.shape[-2:] + + padding_shape: list[int] = _compute_padding([height, width]) + input = torch.nn.functional.pad(input, padding_shape, mode="reflect") + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) + input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) + + # convolve the tensor with the kernel. + output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) + + out = output.view(b, c, h, w) + return out + + +def _gaussian(window_size: int, sigma): + if isinstance(sigma, float): + sigma = torch.tensor([[sigma]]) + + batch_size = sigma.shape[0] + + x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) + + if window_size % 2 == 0: + x = x + 0.5 + + gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) + + return gauss / gauss.sum(-1, keepdim=True) + + +def _gaussian_blur2d(input, kernel_size, sigma): + if isinstance(sigma, tuple): + sigma = torch.tensor([sigma], dtype=input.dtype) + else: + sigma = sigma.to(dtype=input.dtype) + + ky, kx = int(kernel_size[0]), int(kernel_size[1]) + bs = sigma.shape[0] + kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) + kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) + out_x = _filter2d(input, kernel_x[..., None, :]) + out = _filter2d(out_x, kernel_y[..., None]) + + return out diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index a99135300d92..6aa994676577 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -323,8 +323,20 @@ def _sigma_to_alpha_sigma_t(self, sigma): def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" - sigma_min: float = in_sigmas[-1].item() - sigma_max: float = in_sigmas[0].item() + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() rho = 7.0 # 7.0 is the value used in the paper ramp = np.linspace(0, 1, num_inference_steps) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index b427f19e9e03..4b638547b38a 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -358,8 +358,20 @@ def _sigma_to_alpha_sigma_t(self, sigma): def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" - sigma_min: float = in_sigmas[-1].item() - sigma_max: float = in_sigmas[0].item() + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() rho = 7.0 # 7.0 is the value used in the paper ramp = np.linspace(0, 1, num_inference_steps) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index bc8ee24a901c..e762c0ec8bba 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -358,8 +358,20 @@ def _sigma_to_alpha_sigma_t(self, sigma): def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" - sigma_min: float = in_sigmas[-1].item() - sigma_max: float = in_sigmas[0].item() + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() rho = 7.0 # 7.0 is the value used in the paper ramp = np.linspace(0, 1, num_inference_steps) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 6fd4d3bbf7b6..2c0be3b842cc 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -357,8 +357,20 @@ def _sigma_to_alpha_sigma_t(self, sigma): def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" - sigma_min: float = in_sigmas[-1].item() - sigma_max: float = in_sigmas[0].item() + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() rho = 7.0 # 7.0 is the value used in the paper ramp = np.linspace(0, 1, num_inference_steps) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 59d9af9f55b6..53dc2ae15432 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -144,7 +144,10 @@ def __init__( prediction_type: str = "epsilon", interpolation_type: str = "linear", use_karras_sigmas: Optional[bool] = False, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, timestep_spacing: str = "linspace", + timestep_type: str = "discrete", # can be "discrete" or "continuous" steps_offset: int = 0, ): if trained_betas is not None: @@ -164,13 +167,22 @@ def __init__( self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas) + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() + + sigmas = torch.from_numpy(sigmas[::-1].copy()).to(dtype=torch.float32) + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) # setable values self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps) + + # TODO: Support the full EDM scalings for all prediction types and timestep types + if timestep_type == "continuous" and prediction_type == "v_prediction": + self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]) + else: + self.timesteps = timesteps + + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + self.is_scale_input_called = False self.use_karras_sigmas = use_karras_sigmas @@ -268,10 +280,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) - sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas).to(device=device) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) - self.timesteps = torch.from_numpy(timesteps).to(device=device) + # TODO: Support the full EDM scalings for all prediction types and timestep types + if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction": + self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device) + else: + self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device) + + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) self._step_index = None def _sigma_to_t(self, sigma, log_sigmas): @@ -301,8 +318,20 @@ def _sigma_to_t(self, sigma, log_sigmas): def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" - sigma_min: float = in_sigmas[-1].item() - sigma_max: float = in_sigmas[0].item() + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() rho = 7.0 # 7.0 is the value used in the paper ramp = np.linspace(0, 1, num_inference_steps) @@ -412,7 +441,7 @@ def step( elif self.config.prediction_type == "epsilon": pred_original_sample = sample - sigma_hat * model_output elif self.config.prediction_type == "v_prediction": - # * c_out + input * c_skip + # denoised = model_output * c_out + input * c_skip pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) else: raise ValueError( diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 980dbd1bf839..460299cf2ec1 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -303,8 +303,20 @@ def _sigma_to_t(self, sigma, log_sigmas): def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" - sigma_min: float = in_sigmas[-1].item() - sigma_max: float = in_sigmas[0].item() + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() rho = 7.0 # 7.0 is the value used in the paper ramp = np.linspace(0, 1, num_inference_steps) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index e74dd868d835..aae5a15abca2 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -324,8 +324,20 @@ def _sigma_to_t(self, sigma, log_sigmas): def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" - sigma_min: float = in_sigmas[-1].item() - sigma_max: float = in_sigmas[0].item() + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() rho = 7.0 # 7.0 is the value used in the paper ramp = np.linspace(0, 1, num_inference_steps) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index ac590e5713ca..3248520aa9a5 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -335,8 +335,20 @@ def _sigma_to_t(self, sigma, log_sigmas): def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" - sigma_min: float = in_sigmas[-1].item() - sigma_max: float = in_sigmas[0].item() + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() rho = 7.0 # 7.0 is the value used in the paper ramp = np.linspace(0, 1, num_inference_steps) diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index a6d82de80b88..d778f37ec059 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -337,8 +337,20 @@ def _sigma_to_alpha_sigma_t(self, sigma): def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" - sigma_min: float = in_sigmas[-1].item() - sigma_max: float = in_sigmas[0].item() + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() rho = 7.0 # 7.0 is the value used in the paper ramp = np.linspace(0, 1, num_inference_steps) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 360727ab2fc5..c19b15f2f483 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -32,6 +32,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLTemporalDecoder(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderTiny(metaclass=DummyObject): _backends = ["torch"] @@ -272,6 +287,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class UNetSpatioTemporalConditionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class VQModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 3386a95eb7d4..a20a66e07f7f 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1172,6 +1172,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableVideoDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class TextToVideoSDPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/export_utils.py b/src/diffusers/utils/export_utils.py index f7744f9d63eb..45aece18b8fd 100644 --- a/src/diffusers/utils/export_utils.py +++ b/src/diffusers/utils/export_utils.py @@ -3,7 +3,7 @@ import struct import tempfile from contextlib import contextmanager -from typing import List +from typing import List, Union import numpy as np import PIL.Image @@ -115,7 +115,9 @@ def export_to_obj(mesh, output_obj_path: str = None): f.writelines("\n".join(combined_data)) -def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str: +def export_to_video( + video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8 +) -> str: if is_opencv_available(): import cv2 else: @@ -123,9 +125,12 @@ def export_to_video(video_frames: List[np.ndarray], output_video_path: str = Non if output_video_path is None: output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + if isinstance(video_frames[0], PIL.Image.Image): + video_frames = [np.array(frame) for frame in video_frames] + fourcc = cv2.VideoWriter_fourcc(*"mp4v") h, w, c = video_frames[0].shape - video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h)) + video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h)) for i in range(len(video_frames)): img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) video_writer.write(img) diff --git a/tests/models/test_models_unet_spatiotemporal.py b/tests/models/test_models_unet_spatiotemporal.py new file mode 100644 index 000000000000..fa07eaa736ba --- /dev/null +++ b/tests/models/test_models_unet_spatiotemporal.py @@ -0,0 +1,289 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import unittest + +import torch + +from diffusers import UNetSpatioTemporalConditionModel +from diffusers.utils import logging +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_all_close, + torch_device, +) + +from .test_modeling_common import ModelTesterMixin, UNetTesterMixin + + +logger = logging.get_logger(__name__) + +enable_full_determinism() + + +class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): + model_class = UNetSpatioTemporalConditionModel + main_input_name = "sample" + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 2 + num_channels = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device) + + return { + "sample": noise, + "timestep": time_step, + "encoder_hidden_states": encoder_hidden_states, + "added_time_ids": self._get_add_time_ids(), + } + + @property + def input_shape(self): + return (2, 2, 4, 32, 32) + + @property + def output_shape(self): + return (4, 32, 32) + + @property + def fps(self): + return 6 + + @property + def motion_bucket_id(self): + return 127 + + @property + def noise_aug_strength(self): + return 0.02 + + @property + def addition_time_embed_dim(self): + return 32 + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64), + "down_block_types": ( + "CrossAttnDownBlockSpatioTemporal", + "DownBlockSpatioTemporal", + ), + "up_block_types": ( + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + ), + "cross_attention_dim": 32, + "num_attention_heads": 8, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 2, + "sample_size": 32, + "projection_class_embeddings_input_dim": self.addition_time_embed_dim * 3, + "addition_time_embed_dim": self.addition_time_embed_dim, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def _get_add_time_ids(self, do_classifier_free_guidance=True): + add_time_ids = [self.fps, self.motion_bucket_id, self.noise_aug_strength] + + passed_add_embed_dim = self.addition_time_embed_dim * len(add_time_ids) + expected_add_embed_dim = self.addition_time_embed_dim * 3 + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], device=torch_device) + add_time_ids = add_time_ids.repeat(1, 1) + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids]) + + return add_time_ids + + @unittest.skip("Number of Norm Groups is not configurable") + def test_forward_with_norm_groups(self): + pass + + @unittest.skip("Deprecated functionality") + def test_model_attention_slicing(self): + pass + + @unittest.skip("Not supported") + def test_model_with_use_linear_projection(self): + pass + + @unittest.skip("Not supported") + def test_model_with_simple_projection(self): + pass + + @unittest.skip("Not supported") + def test_model_with_class_embeddings_concat(self): + pass + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_enable_works(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + + model.enable_xformers_memory_efficient_attention() + + assert ( + model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ + == "XFormersAttnProcessor" + ), "xformers is not enabled" + + @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS") + def test_gradient_checkpointing(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + + assert not model.is_gradient_checkpointing and model.training + + out = model(**inputs_dict).sample + # run the backwards pass on the model. For backwards pass, for simplicity purpose, + # we won't calculate the loss and rather backprop on out.sum() + model.zero_grad() + + labels = torch.randn_like(out) + loss = (out - labels).mean() + loss.backward() + + # re-instantiate the model now enabling gradient checkpointing + model_2 = self.model_class(**init_dict) + # clone model + model_2.load_state_dict(model.state_dict()) + model_2.to(torch_device) + model_2.enable_gradient_checkpointing() + + assert model_2.is_gradient_checkpointing and model_2.training + + out_2 = model_2(**inputs_dict).sample + # run the backwards pass on the model. For backwards pass, for simplicity purpose, + # we won't calculate the loss and rather backprop on out.sum() + model_2.zero_grad() + loss_2 = (out_2 - labels).mean() + loss_2.backward() + + # compare the output and parameters gradients + self.assertTrue((loss - loss_2).abs() < 1e-5) + named_params = dict(model.named_parameters()) + named_params_2 = dict(model_2.named_parameters()) + for name, param in named_params.items(): + self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) + + def test_model_with_num_attention_heads_tuple(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["num_attention_heads"] = (8, 16) + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_model_with_cross_attention_dim_tuple(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["cross_attention_dim"] = (32, 32) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_gradient_checkpointing_is_applied(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["num_attention_heads"] = (8, 16) + + model_class_copy = copy.copy(self.model_class) + + modules_with_gc_enabled = {} + + # now monkey patch the following function: + # def _set_gradient_checkpointing(self, module, value=False): + # if hasattr(module, "gradient_checkpointing"): + # module.gradient_checkpointing = value + + def _set_gradient_checkpointing_new(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + modules_with_gc_enabled[module.__class__.__name__] = True + + model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new + + model = model_class_copy(**init_dict) + model.enable_gradient_checkpointing() + + EXPECTED_SET = { + "TransformerSpatioTemporalModel", + "CrossAttnDownBlockSpatioTemporal", + "DownBlockSpatioTemporal", + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "UNetMidBlockSpatioTemporal", + } + + assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET + assert all(modules_with_gc_enabled.values()), "All modules should be enabled" + + def test_pickle(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["num_attention_heads"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + sample = model(**inputs_dict).sample + + sample_copy = copy.copy(sample) + + assert (sample - sample_copy).abs().max() < 1e-4 diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py index 83788b836a78..aa755e387b61 100644 --- a/tests/models/test_models_vae.py +++ b/tests/models/test_models_vae.py @@ -23,6 +23,7 @@ from diffusers import ( AsymmetricAutoencoderKL, AutoencoderKL, + AutoencoderKLTemporalDecoder, AutoencoderTiny, ConsistencyDecoderVAE, StableDiffusionPipeline, @@ -248,11 +249,31 @@ def test_output_pretrained(self): ) elif torch_device == "cpu": expected_output_slice = torch.tensor( - [-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026] + [ + -0.1352, + 0.0878, + 0.0419, + -0.0818, + -0.1069, + 0.0688, + -0.1458, + -0.4446, + -0.0026, + ] ) else: expected_output_slice = torch.tensor( - [-0.2421, 0.4642, 0.2507, -0.0438, 0.0682, 0.3160, -0.2018, -0.0727, 0.2485] + [ + -0.2421, + 0.4642, + 0.2507, + -0.0438, + 0.0682, + 0.3160, + -0.2018, + -0.0727, + 0.2485, + ] ) self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2)) @@ -364,6 +385,93 @@ def test_ema_training(self): ... +class AutoncoderKLTemporalDecoderFastTests(ModelTesterMixin, unittest.TestCase): + model_class = AutoencoderKLTemporalDecoder + main_input_name = "sample" + base_precision = 1e-2 + + @property + def dummy_input(self): + batch_size = 3 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + num_frames = 3 + + return {"sample": image, "num_frames": num_frames} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": [32, 64], + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "latent_channels": 4, + "layers_per_block": 2, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_forward_signature(self): + pass + + def test_training(self): + pass + + @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS") + def test_gradient_checkpointing(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + + assert not model.is_gradient_checkpointing and model.training + + out = model(**inputs_dict).sample + # run the backwards pass on the model. For backwards pass, for simplicity purpose, + # we won't calculate the loss and rather backprop on out.sum() + model.zero_grad() + + labels = torch.randn_like(out) + loss = (out - labels).mean() + loss.backward() + + # re-instantiate the model now enabling gradient checkpointing + model_2 = self.model_class(**init_dict) + # clone model + model_2.load_state_dict(model.state_dict()) + model_2.to(torch_device) + model_2.enable_gradient_checkpointing() + + assert model_2.is_gradient_checkpointing and model_2.training + + out_2 = model_2(**inputs_dict).sample + # run the backwards pass on the model. For backwards pass, for simplicity purpose, + # we won't calculate the loss and rather backprop on out.sum() + model_2.zero_grad() + loss_2 = (out_2 - labels).mean() + loss_2.backward() + + # compare the output and parameters gradients + self.assertTrue((loss - loss_2).abs() < 1e-5) + named_params = dict(model.named_parameters()) + named_params_2 = dict(model_2.named_parameters()) + for name, param in named_params.items(): + if "post_quant_conv" in name: + continue + + self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) + + @slow class AutoencoderTinyIntegrationTests(unittest.TestCase): def tearDown(self): @@ -609,7 +717,10 @@ def test_stable_diffusion_decode_fp16(self, seed, expected_slice): @parameterized.expand([(13,), (16,), (27,)]) @require_torch_gpu - @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.") + @unittest.skipIf( + not is_xformers_available(), + reason="xformers is not required when using PyTorch 2.0.", + ) def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed): model = self.get_sd_vae_model(fp16=True) encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True) @@ -627,7 +738,10 @@ def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed): @parameterized.expand([(13,), (16,), (37,)]) @require_torch_gpu - @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.") + @unittest.skipIf( + not is_xformers_available(), + reason="xformers is not required when using PyTorch 2.0.", + ) def test_stable_diffusion_decode_xformers_vs_2_0(self, seed): model = self.get_sd_vae_model() encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64)) @@ -808,7 +922,10 @@ def test_stable_diffusion_decode(self, seed, expected_slice): @parameterized.expand([(13,), (16,), (37,)]) @require_torch_gpu - @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.") + @unittest.skipIf( + not is_xformers_available(), + reason="xformers is not required when using PyTorch 2.0.", + ) def test_stable_diffusion_decode_xformers_vs_2_0(self, seed): model = self.get_sd_vae_model() encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64)) @@ -886,7 +1003,10 @@ def test_sd(self): pipe.to(torch_device) out = pipe( - "horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0) + "horse", + num_inference_steps=2, + output_type="pt", + generator=torch.Generator("cpu").manual_seed(0), ).images[0] actual_output = out[:2, :2, :2].flatten().cpu() @@ -916,7 +1036,8 @@ def test_encode_decode_f16(self): actual_output = sample[0, :2, :2, :2].flatten().cpu() expected_output = torch.tensor( - [-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471], dtype=torch.float16 + [-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471], + dtype=torch.float16, ) assert torch_all_close(actual_output, expected_output, atol=5e-3) @@ -926,17 +1047,24 @@ def test_sd_f16(self): "openai/consistency-decoder", torch_dtype=torch.float16 ) # TODO - update pipe = StableDiffusionPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, vae=vae, safety_checker=None + "runwayml/stable-diffusion-v1-5", + torch_dtype=torch.float16, + vae=vae, + safety_checker=None, ) pipe.to(torch_device) out = pipe( - "horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0) + "horse", + num_inference_steps=2, + output_type="pt", + generator=torch.Generator("cpu").manual_seed(0), ).images[0] actual_output = out[:2, :2, :2].flatten().cpu() expected_output = torch.tensor( - [0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035], dtype=torch.float16 + [0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035], + dtype=torch.float16, ) assert torch_all_close(actual_output, expected_output, atol=5e-3) diff --git a/tests/pipelines/stable_video_diffusion/__init__.py b/tests/pipelines/stable_video_diffusion/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py new file mode 100644 index 000000000000..11978424368f --- /dev/null +++ b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py @@ -0,0 +1,523 @@ +import gc +import random +import tempfile +import unittest + +import numpy as np +import torch +from transformers import ( + CLIPImageProcessor, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +import diffusers +from diffusers import ( + AutoencoderKLTemporalDecoder, + EulerDiscreteScheduler, + StableVideoDiffusionPipeline, + UNetSpatioTemporalConditionModel, +) +from diffusers.utils import is_accelerate_available, is_accelerate_version, load_image, logging +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import ( + CaptureLogger, + disable_full_determinism, + enable_full_determinism, + floats_tensor, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin + + +def to_np(tensor): + if isinstance(tensor, torch.Tensor): + tensor = tensor.detach().cpu().numpy() + + return tensor + + +class StableVideoDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableVideoDiffusionPipeline + params = frozenset(["image"]) + batch_params = frozenset(["image", "generator"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + ] + ) + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNetSpatioTemporalConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=8, + out_channels=4, + down_block_types=( + "CrossAttnDownBlockSpatioTemporal", + "DownBlockSpatioTemporal", + ), + up_block_types=("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal"), + cross_attention_dim=32, + num_attention_heads=8, + projection_class_embeddings_input_dim=96, + addition_time_embed_dim=32, + ) + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + interpolation_type="linear", + num_train_timesteps=1000, + prediction_type="v_prediction", + sigma_max=700.0, + sigma_min=0.002, + steps_offset=1, + timestep_spacing="leading", + timestep_type="continuous", + trained_betas=None, + use_karras_sigmas=True, + ) + + torch.manual_seed(0) + vae = AutoencoderKLTemporalDecoder( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + latent_channels=4, + ) + + torch.manual_seed(0) + config = CLIPVisionConfig( + hidden_size=32, + projection_dim=32, + num_hidden_layers=5, + num_attention_heads=4, + image_size=32, + intermediate_size=37, + patch_size=1, + ) + image_encoder = CLIPVisionModelWithProjection(config) + + torch.manual_seed(0) + feature_extractor = CLIPImageProcessor(crop_size=32, size=32) + components = { + "unet": unet, + "image_encoder": image_encoder, + "scheduler": scheduler, + "vae": vae, + "feature_extractor": feature_extractor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + image = floats_tensor((1, 3, 32, 32), rng=random.Random(0)).to(device) + inputs = { + "generator": generator, + "image": image, + "num_inference_steps": 2, + "output_type": "pt", + "min_guidance_scale": 1.0, + "max_guidance_scale": 2.5, + "num_frames": 2, + "height": 32, + "width": 32, + } + return inputs + + @unittest.skip("Deprecated functionality") + def test_attention_slicing_forward_pass(self): + pass + + @unittest.skip("Batched inference works and outputs look correct, but the test is failing") + def test_inference_batch_single_identical( + self, + batch_size=2, + expected_max_diff=1e-4, + ): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for components in pipe.components.values(): + if hasattr(components, "set_default_attn_processor"): + components.set_default_attn_processor() + pipe.to(torch_device) + + pipe.set_progress_bar_config(disable=None) + inputs = self.get_dummy_inputs(torch_device) + + # Reset generator in case it is has been used in self.get_dummy_inputs + inputs["generator"] = torch.Generator("cpu").manual_seed(0) + + logger = logging.get_logger(pipe.__module__) + logger.setLevel(level=diffusers.logging.FATAL) + + # batchify inputs + batched_inputs = {} + batched_inputs.update(inputs) + + batched_inputs["generator"] = [torch.Generator("cpu").manual_seed(0) for i in range(batch_size)] + batched_inputs["image"] = torch.cat([inputs["image"]] * batch_size, dim=0) + + output = pipe(**inputs).frames + output_batch = pipe(**batched_inputs).frames + + assert len(output_batch) == batch_size + + max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max() + assert max_diff < expected_max_diff + + @unittest.skip("Test is similar to test_inference_batch_single_identical") + def test_inference_batch_consistent(self): + pass + + def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + output = pipe(**self.get_dummy_inputs(generator_device)).frames[0] + output_tuple = pipe(**self.get_dummy_inputs(generator_device), return_dict=False)[0] + + max_diff = np.abs(to_np(output) - to_np(output_tuple)).max() + self.assertLess(max_diff, expected_max_difference) + + @unittest.skip("Test is currently failing") + def test_float16_inference(self, expected_max_diff=5e-2): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + components = self.get_dummy_components() + pipe_fp16 = self.pipeline_class(**components) + for component in pipe_fp16.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + pipe_fp16.to(torch_device, torch.float16) + pipe_fp16.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs).frames[0] + + fp16_inputs = self.get_dummy_inputs(torch_device) + output_fp16 = pipe_fp16(**fp16_inputs).frames[0] + + max_diff = np.abs(to_np(output) - to_np(output_fp16)).max() + self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.") + + @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + def test_save_load_float16(self, expected_max_diff=1e-2): + components = self.get_dummy_components() + for name, module in components.items(): + if hasattr(module, "half"): + components[name] = module.to(torch_device).half() + + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs).frames[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for name, component in pipe_loaded.components.items(): + if hasattr(component, "dtype"): + self.assertTrue( + component.dtype == torch.float16, + f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.", + ) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs).frames[0] + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess( + max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." + ) + + def test_save_load_optional_components(self, expected_max_difference=1e-4): + if not hasattr(self.pipeline_class, "_optional_components"): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # set all optional components to None + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output = pipe(**inputs).frames[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(generator_device) + output_loaded = pipe_loaded(**inputs).frames[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, expected_max_difference) + + def test_save_load_local(self, expected_max_difference=9e-4): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs).frames[0] + + logger = logging.get_logger("diffusers.pipelines.pipeline_utils") + logger.setLevel(diffusers.logging.INFO) + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + + with CaptureLogger(logger) as cap_logger: + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + + for name in pipe_loaded.components.keys(): + if name not in pipe_loaded._optional_components: + assert name in str(cap_logger) + + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs).frames[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, expected_max_difference) + + @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + def test_to_device(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + pipe.to("cpu") + model_devices = [ + component.device.type for component in pipe.components.values() if hasattr(component, "device") + ] + self.assertTrue(all(device == "cpu" for device in model_devices)) + + output_cpu = pipe(**self.get_dummy_inputs("cpu")).frames[0] + self.assertTrue(np.isnan(output_cpu).sum() == 0) + + pipe.to("cuda") + model_devices = [ + component.device.type for component in pipe.components.values() if hasattr(component, "device") + ] + self.assertTrue(all(device == "cuda" for device in model_devices)) + + output_cuda = pipe(**self.get_dummy_inputs("cuda")).frames[0] + self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) + + def test_to_dtype(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) + + pipe.to(torch_dtype=torch.float16) + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) + + @unittest.skipIf( + torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"), + reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher", + ) + def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_offload = pipe(**inputs).frames[0] + + pipe.enable_sequential_cpu_offload() + + inputs = self.get_dummy_inputs(generator_device) + output_with_offload = pipe(**inputs).frames[0] + + max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max() + self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results") + + @unittest.skipIf( + torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"), + reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher", + ) + def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4): + generator_device = "cpu" + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(generator_device) + output_without_offload = pipe(**inputs).frames[0] + + pipe.enable_model_cpu_offload() + inputs = self.get_dummy_inputs(generator_device) + output_with_offload = pipe(**inputs).frames[0] + + max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max() + self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results") + offloaded_modules = [ + v + for k, v in pipe.components.items() + if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload + ] + ( + self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)), + f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}", + ) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + disable_full_determinism() + + expected_max_diff = 9e-4 + + if not self.test_xformers_attention: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output_without_offload = pipe(**inputs).frames[0] + output_without_offload = ( + output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload + ) + + pipe.enable_xformers_memory_efficient_attention() + inputs = self.get_dummy_inputs(torch_device) + output_with_offload = pipe(**inputs).frames[0] + output_with_offload = ( + output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload + ) + + max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max() + self.assertLess(max_diff, expected_max_diff, "XFormers attention should not affect the inference results") + + enable_full_determinism() + + +@slow +@require_torch_gpu +class StableVideoDiffusionPipelineSlowTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_sd_video(self): + pipe = StableVideoDiffusionPipeline.from_pretrained( + "stabilityai/stable-video-diffusion-img2vid", + variant="fp16", + torch_dtype=torch.float16, + ) + pipe = pipe.to(torch_device) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true" + ) + + generator = torch.Generator("cpu").manual_seed(0) + num_frames = 3 + + output = pipe( + image=image, + num_frames=num_frames, + generator=generator, + num_inference_steps=3, + output_type="np", + ) + + image = output.frames[0] + assert image.shape == (num_frames, 576, 1024, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.8592, 0.8645, 0.8499, 0.8722, 0.8769, 0.8421, 0.8557, 0.8528, 0.8285]) + assert numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice.flatten()) < 1e-3 diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py index fa885a0542eb..3249d7032bad 100644 --- a/tests/schedulers/test_scheduler_euler.py +++ b/tests/schedulers/test_scheduler_euler.py @@ -37,6 +37,14 @@ def test_prediction_type(self): for prediction_type in ["epsilon", "v_prediction"]: self.check_over_configs(prediction_type=prediction_type) + def test_timestep_type(self): + timestep_types = ["discrete", "continuous"] + for timestep_type in timestep_types: + self.check_over_configs(timestep_type=timestep_type) + + def test_karras_sigmas(self): + self.check_over_configs(use_karras_sigmas=True, sigma_min=0.02, sigma_max=700.0) + def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py index 8bc95b38cf34..08c5ad5c3a50 100755 --- a/tests/schedulers/test_schedulers.py +++ b/tests/schedulers/test_schedulers.py @@ -352,8 +352,8 @@ def check_over_configs(self, time_step=0, **config): _ = scheduler.scale_model_input(sample, scaled_sigma_max) _ = new_scheduler.scale_model_input(sample, scaled_sigma_max) elif scheduler_class != VQDiffusionScheduler: - _ = scheduler.scale_model_input(sample, 0) - _ = new_scheduler.scale_model_input(sample, 0) + _ = scheduler.scale_model_input(sample, scheduler.timesteps[-1]) + _ = new_scheduler.scale_model_input(sample, scheduler.timesteps[-1]) # Set the seed before step() as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):