diff --git a/examples/research_projects/motionctrl_svd/README.md b/examples/research_projects/motionctrl_svd/README.md new file mode 100644 index 000000000000..e6c310fe77f4 --- /dev/null +++ b/examples/research_projects/motionctrl_svd/README.md @@ -0,0 +1,73 @@ +### MotionCtrl SVD + +[MotionCtrl](https://arxiv.org/abs/2312.03641) is a method that allows flexible control over object and camera movement in video diffusion models. The implementation here is only for [Stable Video Diffusion](https://wzhouxiff.github.io/projects/MotionCtrl/) as presented by the authors. You can find a more implementation-oriented description about it in [this](https://github.com/huggingface/diffusers/issues/6688#issuecomment-1913459070) comment. You can find example results, some useful discussion and MotionCtrl conversion script [here](https://github.com/huggingface/diffusers/pull/6844). + +Paper: https://arxiv.org/abs/2312.03641 +Project site: https://wzhouxiff.github.io/projects/MotionCtrl/ +Colab: https://colab.research.google.com/drive/17xIdW-xWk4hCAIkGq0OfiJYUqwWSPSAz?usp=sharing +YouTube: Feature on [Two Minute Papers](https://youtu.be/2hfPVBDMB-o). + +### Inference + +```py +import torch +from diffusers import DiffusionPipeline +from diffusers.utils import export_to_gif, load_image + +from pipeline_stable_video_motionctrl_diffusion import StableVideoMotionCtrlDiffusionPipeline +from unet_motionctrl import UNetSpatioTemporalConditionMotionCtrlModel + +# Initialize pipeline +ckpt = "a-r-r-o-w/motionctrl-svd" +unet = UNetSpatioTemporalConditionMotionCtrlModel.from_pretrained(ckpt, subfolder="unet", torch_dtype=torch.float16) +pipe = StableVideoMotionCtrlDiffusionPipeline.from_pretrained( + ckpt, + unet=unet, + torch_dtype=torch.float16, + variant="fp16", +).to("cuda") + +# Input image and camera pose +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png") +camera_pose = [ + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.2, 0.0, 0.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.28750000000000003, 0.0, 0.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.37500000000000006, 0.0, 0.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.4625000000000001, 0.0, 0.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.55, 0.0, 0.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.6375000000000002, 0.0, 0.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.7250000000000001, 0.0, 0.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.8125000000000002, 0.0, 0.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.9000000000000001, 0.0, 0.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.9875000000000003, 0.0, 0.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0750000000000002, 0.0, 0.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.1625000000000003, 0.0, 0.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.2500000000000002, 0.0, 0.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.3375000000000001, 0.0, 0.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.4250000000000003, 0.0, 0.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.5125000000000004, 0.0, 0.0, 1.0, 0.0], +] + +# Set MotionCtrl scale +pipe.unet.set_motionctrl_scale(0.8) + +# Generation (make sure num_frames == len(camera_pose)) +num_frames = 16 +frames = pipe( + image=image, + camera_pose=camera_pose, + num_frames=num_frames, + num_inference_steps=20, + decode_chunk_size=2, + motion_bucket_id=255, + fps=15, + min_guidance_scale=1, + 
max_guidance_scale=3.5,
+    generator=torch.Generator().manual_seed(42)
+).frames[0]
+export_to_gif(frames, "animation.gif")
+```
+
+Note that `camera_pose` must be provided for inference. It describes the orientation and position of the camera for each frame (see [camera matrix](https://en.wikipedia.org/wiki/Camera_matrix)). It must be a list of lists where the outer list has length equal to `num_frames` and each inner list contains the `3x4 = 12` values of a flattened camera pose matrix. For some common camera matrices and movements (left, right, up, down, clockwise, anticlockwise, zoom in/out, etc.), refer to [this](https://colab.research.google.com/drive/17xIdW-xWk4hCAIkGq0OfiJYUqwWSPSAz?usp=sharing) notebook.
+
+
diff --git a/examples/research_projects/motionctrl_svd/convert_motionctrl_to_diffusers.py b/examples/research_projects/motionctrl_svd/convert_motionctrl_to_diffusers.py
new file mode 100644
index 000000000000..0dac4b4762e4
--- /dev/null
+++ b/examples/research_projects/motionctrl_svd/convert_motionctrl_to_diffusers.py
@@ -0,0 +1,839 @@
+#!/usr/bin/env python3
+# To be invoked as `python3 convert_motionctrl_to_diffusers.py` from the examples/research_projects/motionctrl_svd directory
+
+import argparse
+
+import torch
+import yaml
+from safetensors.torch import load_file
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+from unet_motionctrl import UNetSpatioTemporalConditionMotionCtrlModel
+from yaml.loader import FullLoader
+
+from diffusers import StableVideoDiffusionPipeline
+from diffusers.models import AutoencoderKLTemporalDecoder
+from diffusers.schedulers import EulerDiscreteScheduler
+from diffusers.utils import logging
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+def create_vae_diffusers_config(original_config, image_size: int):
+    r"""
+    Creates a vae config for diffusers based on the config of the LDM.
+    """
+    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["decoder_config"]["params"]
+    block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
+    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
+
+    vae_config = {
+        "sample_size": image_size,
+        "in_channels": vae_params["in_channels"],
+        "out_channels": vae_params["out_ch"],
+        "down_block_types": tuple(down_block_types),
+        "block_out_channels": tuple(block_out_channels),
+        "latent_channels": vae_params["z_channels"],
+        "layers_per_block": vae_params["num_res_blocks"],
+    }
+
+    return vae_config
+
+
+def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
+    r"""
+    Creates a unet config for diffusers based on the config of the LDM.
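+
+    The UNet hyperparameters (e.g. `model_channels`, `channel_mult`, `attention_resolutions`, `transformer_depth`,
+    `adm_in_channels`) are read from the LDM-style YAML config and mapped onto the corresponding arguments of the
+    diffusers spatio-temporal UNet config.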
+ """ + if controlnet: + unet_params = original_config["model"]["params"]["control_stage_config"]["params"] + else: + if ( + "unet_config" in original_config["model"]["params"] + and original_config["model"]["params"]["unet_config"] is not None + ): + unet_params = original_config["model"]["params"]["unet_config"]["params"] + else: + unet_params = original_config["model"]["params"]["network_config"]["params"] + + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["decoder_config"]["params"] + + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = ( + "CrossAttnDownBlockSpatioTemporal" + if resolution in unet_params["attention_resolutions"] + else "DownBlockSpatioTemporal" + ) + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = ( + "CrossAttnUpBlockSpatioTemporal" + if resolution in unet_params["attention_resolutions"] + else "UpBlockSpatioTemporal" + ) + up_block_types.append(block_type) + resolution //= 2 + + if unet_params["transformer_depth"] is not None: + transformer_layers_per_block = ( + unet_params["transformer_depth"] + if isinstance(unet_params["transformer_depth"], int) + else list(unet_params["transformer_depth"]) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) + + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] + head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params["context_dim"] is not None: + context_dim = ( + unet_params["context_dim"] + if isinstance(unet_params["context_dim"], int) + else unet_params["context_dim"][0] + ) + + if "num_classes" in unet_params: + if unet_params["num_classes"] == "sequential": + addition_time_embed_dim = 256 + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params["adm_in_channels"] + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params["in_channels"], + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params["num_res_blocks"], + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if "disable_self_attentions" in unet_params: + config["only_cross_attention"] = unet_params["disable_self_attentions"] + + if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): + config["num_class_embeds"] = unet_params["num_classes"] + + if controlnet: + 
config["conditioning_channels"] = unet_params["hint_channels"] + else: + config["out_channels"] = unet_params["out_channels"] + config["up_block_types"] = tuple(up_block_types) + + return config + + +def assign_to_checkpoint( + paths, + checkpoint, + old_checkpoint, + attention_paths_to_split=None, + additional_replacements=None, + config=None, + mid_block_suffix="", +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + if mid_block_suffix is not None: + mid_block_suffix = f".{mid_block_suffix}" + else: + mid_block_suffix = "" + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", f"mid_block.resnets.0{mid_block_suffix}") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", f"mid_block.resnets.1{mid_block_suffix}") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("time_stack", "temporal_transformer_blocks") + + new_item = new_item.replace("time_pos_embed.0.bias", "time_pos_embed.linear_1.bias") + new_item = new_item.replace("time_pos_embed.0.weight", "time_pos_embed.linear_1.weight") + new_item = new_item.replace("time_pos_embed.2.bias", "time_pos_embed.linear_2.bias") + new_item = new_item.replace("time_pos_embed.2.weight", "time_pos_embed.linear_2.weight") + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. 
Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = new_item.replace("time_stack.", "") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def convert_ldm_unet_checkpoint( + checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False +): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + if skip_extract_state_dict: + unet_state_dict = checkpoint + else: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." 
+ ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + spatial_resnets = [ + key + for key in input_blocks[i] + if f"input_blocks.{i}.0" in key + and ( + f"input_blocks.{i}.0.op" not in key + and f"input_blocks.{i}.0.time_stack" not in key + and f"input_blocks.{i}.0.time_mixer" not in key + ) + ] + temporal_resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0.time_stack" in key] + # import ipdb; ipdb.set_trace() + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(spatial_resnets) + meta_path = { + "old": f"input_blocks.{i}.0", + "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + paths = renew_resnet_paths(temporal_resnets) 
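+        # remap the temporal (time_stack) resnet weights of this input block onto the diffusers `temporal_res_block` keys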
+ meta_path = { + "old": f"input_blocks.{i}.0", + "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + # TODO resnet time_mixer.mix_factor + if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict: + new_checkpoint[ + f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor" + ] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + # import ipdb; ipdb.set_trace() + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_spatial = [key for key in resnet_0 if "time_stack" not in key and "time_mixer" not in key] + resnet_0_paths = renew_resnet_paths(resnet_0_spatial) + # import ipdb; ipdb.set_trace() + assign_to_checkpoint( + resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block" + ) + + resnet_0_temporal = [key for key in resnet_0 if "time_stack" in key and "time_mixer" not in key] + resnet_0_paths = renew_resnet_paths(resnet_0_temporal) + assign_to_checkpoint( + resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block" + ) + + resnet_1_spatial = [key for key in resnet_1 if "time_stack" not in key and "time_mixer" not in key] + resnet_1_paths = renew_resnet_paths(resnet_1_spatial) + assign_to_checkpoint( + resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block" + ) + + resnet_1_temporal = [key for key in resnet_1 if "time_stack" in key and "time_mixer" not in key] + resnet_1_paths = renew_resnet_paths(resnet_1_temporal) + assign_to_checkpoint( + resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block" + ) + + new_checkpoint["mid_block.resnets.0.time_mixer.mix_factor"] = unet_state_dict[ + "middle_block.0.time_mixer.mix_factor" + ] + new_checkpoint["mid_block.resnets.1.time_mixer.mix_factor"] = unet_state_dict[ + "middle_block.2.time_mixer.mix_factor" + ] + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + spatial_resnets = [ + key + for key in output_blocks[i] + if f"output_blocks.{i}.0" in key + and (f"output_blocks.{i}.0.time_stack" not in key and "time_mixer" not in key) + ] + # import ipdb; ipdb.set_trace() + + temporal_resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0.time_stack" in key] + + paths = 
renew_resnet_paths(spatial_resnets)
+            meta_path = {
+                "old": f"output_blocks.{i}.0",
+                "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
+            }
+            assign_to_checkpoint(
+                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+            )
+
+            paths = renew_resnet_paths(temporal_resnets)
+            meta_path = {
+                "old": f"output_blocks.{i}.0",
+                "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
+            }
+            assign_to_checkpoint(
+                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+            )
+
+            if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
+                new_checkpoint[
+                    f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
+                ] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
+
+            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+            if ["conv.bias", "conv.weight"] in output_block_list.values():
+                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+                    f"output_blocks.{i}.{index}.conv.weight"
+                ]
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+                    f"output_blocks.{i}.{index}.conv.bias"
+                ]
+
+                # Clear attentions as they have been attributed above.
+                if len(attentions) == 2:
+                    attentions = []
+
+            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and "conv" not in key]
+            if len(attentions):
+                paths = renew_attention_paths(attentions)
+                meta_path = {
+                    "old": f"output_blocks.{i}.1",
+                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+                }
+                assign_to_checkpoint(
+                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+                )
+        else:
+            spatial_layers = [
+                layer for layer in output_block_layers if "time_stack" not in layer and "time_mixer" not in layer
+            ]
+            resnet_0_paths = renew_resnet_paths(spatial_layers, n_shave_prefix_segments=1)
+            for path in resnet_0_paths:
+                old_path = ".".join(["output_blocks", str(i), path["old"]])
+                new_path = ".".join(
+                    ["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "spatial_res_block", path["new"]]
+                )
+
+                new_checkpoint[new_path] = unet_state_dict[old_path]
+
+            temporal_layers = [
+                layer for layer in output_block_layers if "time_stack" in layer and "time_mixer" not in layer
+            ]
+            resnet_0_paths = renew_resnet_paths(temporal_layers, n_shave_prefix_segments=1)
+            for path in resnet_0_paths:
+                old_path = ".".join(["output_blocks", str(i), path["old"]])
+                new_path = ".".join(
+                    ["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "temporal_res_block", path["new"]]
+                )
+
+                new_checkpoint[new_path] = unet_state_dict[old_path]
+
+            new_checkpoint["up_blocks.0.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
+                f"output_blocks.{str(i)}.0.time_mixer.mix_factor"
+            ]
+
+    return new_checkpoint
+
+
+def conv_attn_to_linear(checkpoint):
+    keys = list(checkpoint.keys())
+    attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
+    for key in keys:
+        if ".".join(key.split(".")[-2:]) in attn_keys:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0, 0]
+        elif "proj_attn.weight" in key:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0]
+
+
+def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0, is_temporal=False):
+    """
+    Updates paths inside resnets to
the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # Temporal resnet + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = new_item.replace("time_stack.", "temporal_res_block.") + + # Spatial resnet + new_item = new_item.replace("conv1", "spatial_res_block.conv1") + new_item = new_item.replace("norm1", "spatial_res_block.norm1") + + new_item = new_item.replace("conv2", "spatial_res_block.conv2") + new_item = new_item.replace("norm2", "spatial_res_block.norm2") + + new_item = new_item.replace("nin_shortcut", "spatial_res_block.conv_shortcut") + + new_item = new_item.replace("mix_factor", "spatial_res_block.time_mixer.mix_factor") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = "first_stage_model." 
if any(k.startswith("first_stage_model.") for k in keys) else "" + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.conv_out.time_mix_conv.weight"] + new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.conv_out.time_mix_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], 
config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def read_config_file(filename): + # The yaml file contains annotations that certain values should + # loaded as tuples. + with open(filename) as f: + original_config = yaml.load(f, FullLoader) + + return original_config + + +def load_original_state_dict(filename: str): + if filename.endswith("safetensors"): + state_dict = load_file(filename) + elif filename.endswith("ckpt"): + state_dict = torch.load(filename, mmap=True, map_location="cpu") + else: + raise ValueError("File type is not supported") + + if isinstance(state_dict, dict) and "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + + return state_dict + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint to convert.", required=True) + parser.add_argument( + "--config_file", type=str, help="The config json file corresponding to the architecture.", required=True + ) + parser.add_argument("--output_path", default=None, type=str, help="Path to the output model.", required=True) + parser.add_argument("--sample_size", type=int, default=768, help="VAE sample size") + parser.add_argument( + "--push_to_hub", action="store_true", default=False, help="Whether to push to huggingface hub or not." 
+ ) + args = parser.parse_args() + + original_config = read_config_file(args.config_file) + state_dict = load_original_state_dict(args.checkpoint_path) + + vae_config = create_vae_diffusers_config(original_config, args.sample_size) + vae = AutoencoderKLTemporalDecoder(**vae_config, use_quant_conv=False) + vae_state_dict = convert_ldm_vae_checkpoint(state_dict, vae_config) + + remove = [] + for key in vae_state_dict.keys(): + # i'm sorry to hurt your eyes + if ("encoder" in key) or ( + "decoder" in key and "resnets" and (("temporal_res_block" in key) or ("time_mixer" in key)) + ): + remove.append(key) + + for key in remove: + vae_state_dict[key.replace("spatial_res_block.", "")] = vae_state_dict.pop(key) + + missing_keys, unexpected_keys = vae.load_state_dict(vae_state_dict) + logger.info(f"[VAE] missing_keys: {missing_keys}") + logger.info(f"[VAE] unexpected_keys: {unexpected_keys}") + + unet_config = create_unet_diffusers_config(original_config, args.sample_size) + unet_config["motionctrl_kwargs"] = { + "camera_pose_embed_dim": 1, + "camera_pose_dim": 12, + } + unet_state_dict = convert_ldm_unet_checkpoint(state_dict, unet_config) + unet = UNetSpatioTemporalConditionMotionCtrlModel.from_config(unet_config) + missing_keys, unexpected_keys = unet.load_state_dict(unet_state_dict) + logger.info(f"[UNet] missing_keys: {missing_keys}") + logger.info(f"[UNet] unexpected_keys: {unexpected_keys}") + logger.info("UNet conversion succeeded") + + original_svd_model_id = "stabilityai/stable-video-diffusion-img2vid-xt" + image_encoder = CLIPVisionModelWithProjection.from_pretrained(original_svd_model_id, subfolder="image_encoder") + feature_extractor = CLIPImageProcessor() + scheduler = EulerDiscreteScheduler.from_pretrained(original_svd_model_id, subfolder="scheduler") + + pipe = StableVideoDiffusionPipeline( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + + pipe.save_pretrained(args.output_path) + if args.push_to_hub: + logger.info("Pushing float32 version to HF hub") + pipe.push_to_hub(args.output_path) + + pipe.to(dtype=torch.float16) + pipe.save_pretrained(args.output_path, variant="fp16") + if args.push_to_hub: + logger.info("Pushing float16 version to HF hub") + pipe.push_to_hub(args.output_path, variant="fp16") diff --git a/examples/research_projects/motionctrl_svd/pipeline_stable_video_motionctrl_diffusion.py b/examples/research_projects/motionctrl_svd/pipeline_stable_video_motionctrl_diffusion.py new file mode 100644 index 000000000000..9f5f0425d499 --- /dev/null +++ b/examples/research_projects/motionctrl_svd/pipeline_stable_video_motionctrl_diffusion.py @@ -0,0 +1,646 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# MotionCtrl for Stable Video Diffusion +# +# Paper: https://arxiv.org/pdf/2312.03641.pdf +# Authors: Zhouxia Wang, Ziyang Yuan, Xintao Wang, Tianshui Chen, Menghan Xia, Ping Luo, Ying Shan +# Project Page: https://wzhouxiff.github.io/projects/MotionCtrl/ +# Code: https://github.com/TencentARC/MotionCtrl +# +# Adapted to diffusers by [Aryan V S](https://github.com/a-r-r-o-w). + +import inspect +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +from unet_motionctrl import UNetSpatioTemporalConditionMotionCtrlModel + +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKLTemporalDecoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import ( + StableVideoDiffusionPipelineOutput, + _resize_with_antialiasing, +) +from diffusers.schedulers import EulerDiscreteScheduler +from diffusers.utils import logging, replace_example_docstring +from diffusers.utils.torch_utils import is_compiled_module, randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline + >>> from diffusers.utils import export_to_gif, load_image + >>> from pipeline_stable_video_motionctrl_diffusion import StableVideoMotionCtrlDiffusionPipeline + >>> from unet_motionctrl import UNetSpatioTemporalConditionMotionCtrlModel + + >>> # Initialize pipeline + >>> ckpt = "a-r-r-o-w/motionctrl-svd" + >>> unet = UNetSpatioTemporalConditionMotionCtrlModel.from_pretrained(ckpt, subfolder="unet", torch_dtype=torch.float16) + >>> pipe = StableVideoMotionCtrlDiffusionPipeline.from_pretrained( + ... ckpt, + ... unet=unet, + ... torch_dtype=torch.float16, + ... variant="fp16", + >>> ).to("cuda") + + >>> # Input image and camera pose + >>> image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png") + >>> camera_pose = [ + ... [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.2, 0.0, 0.0, 1.0, 0.0], + ... [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.28750000000000003, 0.0, 0.0, 1.0, 0.0], + ... [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.37500000000000006, 0.0, 0.0, 1.0, 0.0], + ... [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.4625000000000001, 0.0, 0.0, 1.0, 0.0], + ... [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.55, 0.0, 0.0, 1.0, 0.0], + ... [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.6375000000000002, 0.0, 0.0, 1.0, 0.0], + ... [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.7250000000000001, 0.0, 0.0, 1.0, 0.0], + ... [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.8125000000000002, 0.0, 0.0, 1.0, 0.0], + ... [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.9000000000000001, 0.0, 0.0, 1.0, 0.0], + ... [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -0.9875000000000003, 0.0, 0.0, 1.0, 0.0], + ... [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0750000000000002, 0.0, 0.0, 1.0, 0.0], + ... [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.1625000000000003, 0.0, 0.0, 1.0, 0.0], + ... [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.2500000000000002, 0.0, 0.0, 1.0, 0.0], + ... [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.3375000000000001, 0.0, 0.0, 1.0, 0.0], + ... [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.4250000000000003, 0.0, 0.0, 1.0, 0.0], + ... 
[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.5125000000000004, 0.0, 0.0, 1.0, 0.0],
+        >>> ]
+
+        >>> # Set MotionCtrl scale
+        >>> pipe.unet.set_motionctrl_scale(0.8)
+
+        >>> # Generation (make sure num_frames == len(camera_pose))
+        >>> num_frames = 16
+        >>> frames = pipe(
+        ...     image=image,
+        ...     camera_pose=camera_pose,
+        ...     num_frames=num_frames,
+        ...     num_inference_steps=20,
+        ...     decode_chunk_size=2,
+        ...     motion_bucket_id=255,
+        ...     fps=15,
+        ...     min_guidance_scale=1,
+        ...     max_guidance_scale=3.5,
+        ...     generator=torch.Generator().manual_seed(42)
+        >>> ).frames[0]
+        >>> export_to_gif(frames, "animation.gif")
+        ```
+"""
+
+
+def _append_dims(x, target_dims):
+    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+    dims_to_append = target_dims - x.ndim
+    if dims_to_append < 0:
+        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+    return x[(...,) + (None,) * dims_to_append]
+
+
+# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
+def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
+    batch_size, channels, num_frames, height, width = video.shape
+    outputs = []
+    for batch_idx in range(batch_size):
+        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
+        batch_output = processor.postprocess(batch_vid, output_type)
+
+        outputs.append(batch_output)
+
+    if output_type == "np":
+        outputs = np.stack(outputs)
+
+    elif output_type == "pt":
+        outputs = torch.stack(outputs)
+
+    elif not output_type == "pil":
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+
+    return outputs
+
+
+class StableVideoMotionCtrlDiffusionPipeline(DiffusionPipeline):
+    r"""
+    Pipeline to generate video from an input image using [MotionCtrl](https://github.com/TencentARC/MotionCtrl).
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+    implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+    Args:
+        vae ([`AutoencoderKLTemporalDecoder`]):
+            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+        image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
+            Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
+        unet ([`UNetSpatioTemporalConditionMotionCtrlModel`]):
+            A `UNetSpatioTemporalConditionMotionCtrlModel` to denoise the encoded image latents.
+        scheduler ([`EulerDiscreteScheduler`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
+        feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            A `CLIPImageProcessor` to extract features from generated images.
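+
+    In addition to the usual Stable Video Diffusion inputs, the pipeline takes a `camera_pose` argument at call time:
+    one flattened 3x4 `[R|t]` camera matrix (12 values) per generated frame. The poses are converted to poses
+    relative to the first frame before they are used to condition the UNet.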
+ """ + + model_cpu_offload_seq = "image_encoder->unet->vae" + _callback_tensor_inputs = ["latents"] + + def __init__( + self, + vae: AutoencoderKLTemporalDecoder, + image_encoder: CLIPVisionModelWithProjection, + unet: UNetSpatioTemporalConditionMotionCtrlModel, + scheduler: EulerDiscreteScheduler, + feature_extractor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion.StableVideoDiffusionPipeline._encode_image + def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.image_processor.pil_to_numpy(image) + image = self.image_processor.numpy_to_pt(image) + + # We normalize the image before resizing to match with the original implementation. + # Then we unnormalize it after resizing. + image = image * 2.0 - 1.0 + image = _resize_with_antialiasing(image, (224, 224)) + image = (image + 1.0) / 2.0 + + # Normalize the image with for CLIP input + image = self.feature_extractor( + images=image, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_image_embeddings = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion.StableVideoDiffusionPipeline._encode_vae_image + def _encode_vae_image( + self, + image: torch.Tensor, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + ): + image = image.to(device=device) + image_latents = self.vae.encode(image).latent_dist.mode() + + if do_classifier_free_guidance: + negative_image_latents = torch.zeros_like(image_latents) + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_latents = torch.cat([negative_image_latents, image_latents]) + + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) + + return image_latents + + # Copied from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion.StableVideoDiffusionPipeline._get_add_time_ids + def _get_add_time_ids( + self, + fps, + motion_bucket_id, + noise_aug_strength, + dtype, + batch_size, + num_videos_per_prompt, + do_classifier_free_guidance, + ): + add_time_ids = [fps, motion_bucket_id, noise_aug_strength] + + passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) + + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids]) + + return add_time_ids + + # Copied from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion.StableVideoDiffusionPipeline.decode_latents + def decode_latents(self, latents, num_frames, decode_chunk_size=14): + # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] + latents = latents.flatten(0, 1) + + latents = 1 / self.vae.config.scaling_factor * latents + + forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward + accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys()) + + # decode decode_chunk_size frames at a time to avoid OOM + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + num_frames_in = latents[i : i + decode_chunk_size].shape[0] + decode_kwargs = {} + if accepts_num_frames: + # we only pass num_frames_in if it's expected + decode_kwargs["num_frames"] = num_frames_in + + frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample + frames.append(frame) + frames = torch.cat(frames, dim=0) + + # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] + frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + frames = frames.float() + return frames + + def check_inputs(self, image, height, width, num_frames, camera_pose): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + camera_pose_lengths = [len(pose) for pose in camera_pose] + if len(camera_pose) != num_frames: + raise ValueError(f"length of `camera_poses` must be equal to 
{num_frames=} but got {len(camera_pose)=}") + if not all(x == 12 for x in camera_pose_lengths): + raise ValueError(f"All camera poses must have 12 values but got {camera_pose_lengths}") + + # Copied from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion.StableVideoDiffusionPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_frames, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_frames, + num_channels_latents // 2, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _to_relative_camera_pose( + self, camera_pose: np.ndarray, keyframe_index: int = 0, keyframe_zero: bool = False + ) -> np.ndarray: + camera_pose = camera_pose.reshape(-1, 3, 4) + rotation_dst = camera_pose[:, :, :3] + translation_dst = camera_pose[:, :, 3:] + + rotation_src = rotation_dst[keyframe_index : keyframe_index + 1].repeat(camera_pose.shape[0], axis=0) + translation_src = translation_dst[keyframe_index : keyframe_index + 1].repeat(camera_pose.shape[0], axis=0) + + rotation_src_inv = rotation_src.transpose(0, 2, 1) + rotation_rel = rotation_dst @ rotation_src_inv + translation_rel = translation_dst - rotation_rel @ translation_src + + rt_rel = np.concatenate([rotation_rel, translation_rel], axis=-1) + rt_rel = rt_rel.reshape(-1, 12) + + if keyframe_zero: + rt_rel[keyframe_index] = np.zeros_like(rt_rel[keyframe_index]) + + return rt_rel + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
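+    # During denoising, `guidance_scale` is replaced with a per-frame tensor (a linspace from `min_guidance_scale`
+    # to `max_guidance_scale`), so the property below handles both scalar and tensor values.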
+ @property + def do_classifier_free_guidance(self): + if isinstance(self.guidance_scale, (int, float)): + return self.guidance_scale > 1 + return self.guidance_scale.max() > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], + camera_pose: List[List[float]], + camera_speed: float = 1.0, + height: int = 576, + width: int = 1024, + num_frames: Optional[int] = None, + num_inference_steps: int = 25, + min_guidance_scale: float = 1.0, + max_guidance_scale: float = 3.0, + fps: int = 7, + motion_bucket_id: int = 127, + noise_aug_strength: float = 0.02, + decode_chunk_size: Optional[int] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + return_dict: bool = True, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + Image or images to guide image generation. If you provide a tensor, it needs to be compatible with + [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_frames (`int`, *optional*): + The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt` + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + min_guidance_scale (`float`, *optional*, defaults to 1.0): + The minimum guidance scale. Used for the classifier free guidance with first frame. + max_guidance_scale (`float`, *optional*, defaults to 3.0): + The maximum guidance scale. Used for the classifier free guidance with last frame. + fps (`int`, *optional*, defaults to 7): + Frames per second. The rate at which the generated images shall be exported to a video after generation. + Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training. + motion_bucket_id (`int`, *optional*, defaults to 127): + The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video. + noise_aug_strength (`float`, *optional*, defaults to 0.02): + The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion. + decode_chunk_size (`int`, *optional*): + The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency + between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once + for maximal quality. Reduce `decode_chunk_size` to reduce memory usage. 
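+            camera_pose (`List[List[float]]`):
+                One flattened 3x4 `[R|t]` camera matrix (12 values) per generated frame; the outer list must have
+                length equal to `num_frames`. The poses are converted to poses relative to the first frame before
+                being used to condition the UNet.
+            camera_speed (`float`, *optional*, defaults to 1.0):
+                Multiplier applied to the translation component of the camera poses.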
+ num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `PIL.Image` or `np.array`. + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] instead of a + plain tuple. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list of lists with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_frames = num_frames if num_frames is not None else self.unet.config.num_frames + decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, num_frames, camera_pose) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analogous to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + self._guidance_scale = max_guidance_scale + + # 3. Encode input image + image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance) + + # NOTE: Stable Diffusion Video was conditioned on fps - 1, which + # is why it is reduced here. + # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188 + fps = fps - 1 + + # 4.
Encode input image using VAE + image = self.image_processor.preprocess(image, height=height, width=width).to(device) + noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype) + image = image + noise_aug_strength * noise + + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.vae.to(dtype=torch.float32) + + image_latents = self._encode_vae_image( + image, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + image_latents = image_latents.to(image_embeddings.dtype) + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + # Repeat the image latents for each frame so we can concatenate them with the noise + # image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width] + image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1) + + # 5. Get Added Time IDs + added_time_ids = self._get_add_time_ids( + fps, + motion_bucket_id, + noise_aug_strength, + image_embeddings.dtype, + batch_size, + num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + added_time_ids = added_time_ids.to(device) + + # 6. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 7. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_frames, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 8. Prepare guidance scale + guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0) + guidance_scale = guidance_scale.to(device, latents.dtype) + guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1) + guidance_scale = _append_dims(guidance_scale, latents.ndim) + + self._guidance_scale = guidance_scale + + # 9. Prepare camera pose conditioning: scale the translation by `camera_speed` and convert to poses relative to the first frame + camera_pose = np.array(camera_pose).reshape(-1, 3, 4) + camera_pose[:, :, -1] = camera_pose[:, :, -1] * camera_speed + camera_pose = self._to_relative_camera_pose(camera_pose) + camera_pose = torch.FloatTensor(camera_pose).to(device=device, dtype=image_embeddings.dtype) + camera_pose = camera_pose.unsqueeze(0) + if self.do_classifier_free_guidance: + camera_pose = camera_pose.repeat(2, 1, 1) + + # 10.
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Concatenate image_latents over channels dimension + latent_model_input = torch.cat([latent_model_input, image_latents], dim=2) + + # predict the noise residual (camera_pose is routed to the temporal transformer blocks through a forward pre-hook) + noise_pred = self.unet( + camera_pose, + latent_model_input, + t, + encoder_hidden_states=image_embeddings, + added_time_ids=added_time_ids, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + frames = self.decode_latents(latents, num_frames, decode_chunk_size) + frames = tensor2vid(frames, self.image_processor, output_type=output_type) + else: + frames = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return frames + + return StableVideoDiffusionPipelineOutput(frames=frames) diff --git a/examples/research_projects/motionctrl_svd/unet_motionctrl.py b/examples/research_projects/motionctrl_svd/unet_motionctrl.py new file mode 100644 index 000000000000..bfddddc5bdb4 --- /dev/null +++ b/examples/research_projects/motionctrl_svd/unet_motionctrl.py @@ -0,0 +1,108 @@ +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn + +from diffusers.models import UNetSpatioTemporalConditionModel +from diffusers.models.attention import TemporalBasicTransformerBlock, _chunked_feed_forward +from diffusers.utils.torch_utils import maybe_allow_in_graph + + +@maybe_allow_in_graph +def _forward_temporal_basic_transformer_block( + self, + camera_pose: torch.FloatTensor, + scale: float, + hidden_states: torch.FloatTensor, + num_frames: int, + encoder_hidden_states: Optional[torch.FloatTensor] = None, +) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0.
Self-Attention + batch_size = hidden_states.shape[0] + + batch_frames, seq_length, channels = hidden_states.shape + batch_size = batch_frames // num_frames + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) + + residual = hidden_states + hidden_states = self.norm_in(hidden_states) + + if self._chunk_size is not None: + hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) + else: + hidden_states = self.ff_in(hidden_states) + + if self.is_res: + hidden_states = hidden_states + residual + + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) + hidden_states = attn_output + hidden_states + + # MotionCtrl specific + camera_pose = camera_pose.repeat_interleave(seq_length, dim=0) # [batch_size * seq_length, num_frames, 12] + residual = hidden_states + hidden_states = torch.cat([hidden_states, camera_pose], dim=-1) + hidden_states = scale * self.cc_projection(hidden_states) + (1 - scale) * residual + + # 3. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = self.norm2(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self._chunk_size is not None: + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.is_res: + hidden_states = ff_output + hidden_states + else: + hidden_states = ff_output + + hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) + + return hidden_states + + +class UNetSpatioTemporalConditionMotionCtrlModel(UNetSpatioTemporalConditionModel): + r"""UNetSpatioTemporalConditionModel for [MotionCtrl SVD](https://arxiv.org/abs/2312.03641).""" + + def __init__(self, motionctrl_kwargs: Dict[str, Any], *args, **kwargs): + super().__init__(*args, **kwargs) + self.motionctrl_scale = 1 + self._camera_pose = None + + camera_pose_embed_dim = motionctrl_kwargs.get("camera_pose_embed_dim") + camera_pose_dim = motionctrl_kwargs.get("camera_pose_dim") + + def pre_hook(module, args): + return (self._camera_pose, self.motionctrl_scale, *args) + + for _, module in self.named_modules(): + if isinstance(module, TemporalBasicTransformerBlock): + cc_projection = nn.Linear( + module.time_mix_inner_dim + camera_pose_embed_dim * camera_pose_dim, module.time_mix_inner_dim + ) + module.add_module("cc_projection", cc_projection) + + new_forward = _forward_temporal_basic_transformer_block.__get__(module, module.__class__) + setattr(module, "forward", new_forward) + module.register_forward_pre_hook(pre_hook) + + def set_motionctrl_scale(self, scale: float): + self.motionctrl_scale = scale + + def forward(self, camera_pose: torch.FloatTensor, *args, **kwargs): + self._camera_pose = camera_pose + return super().forward(*args, **kwargs) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 3d4fccb20779..447acfe1eb4a 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -425,6 
+425,7 @@ def __init__( cross_attention_dim: Optional[int] = None, ): super().__init__() + self.time_mix_inner_dim = time_mix_inner_dim self.is_res = dim == time_mix_inner_dim self.norm_in = nn.LayerNorm(dim) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py index 67540cb7dc7f..4eb740b8f969 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py @@ -205,6 +205,7 @@ def __init__( sample_size: int = 32, scaling_factor: float = 0.18215, force_upcast: float = True, + use_quant_conv: bool = True, ): super().__init__() @@ -226,7 +227,7 @@ def __init__( layers_per_block=layers_per_block, ) - self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None sample_size = ( self.config.sample_size[0] @@ -330,8 +331,9 @@ def encode( [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) + if self.quant_conv is not None: + h = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(h) if not return_dict: return (posterior,)
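For reference, the camera poses passed to the pipeline are rescaled by `camera_speed` and converted to poses relative to the first frame before they condition the temporal attention blocks (see `_to_relative_camera_pose` above). The following minimal, standalone sketch of that conversion uses made-up pose values and is only meant to illustrate the math: the keyframe itself should map to the identity rotation and a zero translation.

```py
import numpy as np


def to_relative(camera_pose: np.ndarray, keyframe_index: int = 0) -> np.ndarray:
    # camera_pose: [num_frames, 12], each row a flattened 3x4 [R|t] matrix
    camera_pose = camera_pose.reshape(-1, 3, 4)
    rotation = camera_pose[:, :, :3]
    translation = camera_pose[:, :, 3:]

    rotation_src = rotation[keyframe_index]
    translation_src = translation[keyframe_index]

    # R_rel = R_dst @ R_src^T, t_rel = t_dst - R_rel @ t_src
    rotation_rel = rotation @ rotation_src.T
    translation_rel = translation - rotation_rel @ translation_src

    return np.concatenate([rotation_rel, translation_rel], axis=-1).reshape(-1, 12)


# Two hypothetical poses: identity, then a small yaw combined with a sideways translation
theta = 0.1
pose_0 = np.eye(3, 4)
pose_1 = np.array(
    [
        [np.cos(theta), 0.0, np.sin(theta), 0.2],
        [0.0, 1.0, 0.0, 0.0],
        [-np.sin(theta), 0.0, np.cos(theta), 0.0],
    ]
)
poses = np.stack([pose_0, pose_1]).reshape(-1, 12)

relative = to_relative(poses)
assert np.allclose(relative[0].reshape(3, 4), np.eye(3, 4))  # keyframe -> [I | 0]
```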