From b14881e38b04ffed1119fa4e4ea9a60294d1cc95 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 23 Jan 2024 11:04:59 +0000 Subject: [PATCH 01/18] update --- .../models/unets/unet_motion_model.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 9654ae508215..09501c5b6ce4 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -89,6 +89,8 @@ def __init__( motion_norm_num_groups: int = 32, motion_max_seq_length: int = 32, use_motion_mid_block: bool = True, + use_input_conv: bool = False, + input_conv_channels: int = 9, ): """Container to store AnimateDiff Motion Modules @@ -113,6 +115,16 @@ def __init__( down_blocks = [] up_blocks = [] + if use_input_conv: + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + input_conv_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + else: + self.conv_in = None + for i, channel in enumerate(block_out_channels): output_channel = block_out_channels[i] down_blocks.append( @@ -419,7 +431,12 @@ def from_unet2d( if not load_weights: return model - model.conv_in.load_state_dict(unet.conv_in.state_dict()) + if motion_adapter.config["use_input_conv"]: + model.conv_in.weight[:, :4, :, :] = unet.conv_in.weight + model.conv_in.bias = unet.conv_in.bias + else: + model.conv_in.load_state_dict(unet.conv_in.state_dict()) + model.time_proj.load_state_dict(unet.time_proj.state_dict()) model.time_embedding.load_state_dict(unet.time_embedding.state_dict()) @@ -488,6 +505,9 @@ def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None: if hasattr(self.mid_block, "motion_modules"): self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict()) + if motion_adapter.config["use_input_conv"]: + self.conv_in.load_state_dict(motion_adapter.conv_in.state_dict()) + def save_motion_modules( self, save_directory: str, From 1169370f193766c9cd698446e62476c8f4fad81c Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 24 Jan 2024 04:00:46 +0000 Subject: [PATCH 02/18] update --- src/diffusers/models/unets/unet_motion_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 09501c5b6ce4..dd0ac0f89fce 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -432,8 +432,12 @@ def from_unet2d( return model if motion_adapter.config["use_input_conv"]: - model.conv_in.weight[:, :4, :, :] = unet.conv_in.weight - model.conv_in.bias = unet.conv_in.bias + model.conv_in = motion_adapter.conv_in + updated_conv_in_weight = torch.cat( + [unet.conv_in.weight, motion_adapter.conv_in.weight[:, 4:, :, :]], dim=1 + ) + model.conv_in.load_state_dict({"weight": updated_conv_in_weight, "bias": unet.conv_in.bias}) + else: model.conv_in.load_state_dict(unet.conv_in.state_dict()) From 7be981cc8cc46dc08f76d03ab1b0337f5dad5e78 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 24 Jan 2024 04:01:08 +0000 Subject: [PATCH 03/18] updaet --- src/diffusers/pipelines/pia/__init__.py | 0 src/diffusers/pipelines/pia/pipeline_pia.py | 1221 +++++++++++++++++++ 2 files changed, 1221 insertions(+) create mode 100644 src/diffusers/pipelines/pia/__init__.py create mode 100644 
src/diffusers/pipelines/pia/pipeline_pia.py diff --git a/src/diffusers/pipelines/pia/__init__.py b/src/diffusers/pipelines/pia/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py new file mode 100644 index 000000000000..0caedd8e67be --- /dev/null +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -0,0 +1,1221 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.fft as fft +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler + >>> from diffusers.utils import export_to_gif + + >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2") + >>> pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter) + >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False) + >>> output = pipe(prompt="A corgi walking in the park") + >>> frames = output.frames[0] + >>> export_to_gif(frames, "animation.gif") + ``` +""" + +RANGE_LIST = [ + [1.0, 0.9, 0.85, 0.85, 0.85, 0.8], # 0 Small Motion + [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75], # Moderate Motion + [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5], # Large Motion + [1.0, 0.9, 0.85, 0.85, 0.85, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.85, 0.85, 0.9, 1.0], # Loop + [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75, 0.75, 0.75, 0.75, 0.75, 0.78, 0.79, 0.8, 0.8, 1.0], # Loop + [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5, 0.6, 0.7, 0.7, 0.7, 0.7, 0.8, 1.0], # Loop + [0.5, 0.2], # Style Transfer Large Motion + [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2], # Style Transfer Moderate Motion + [0.5, 0.4, 0.4, 0.4, 0.35, 0.3], # Style Transfer Candidate 
Small Motion +] + + +def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + + elif output_type == "pt": + outputs = torch.stack(outputs) + + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") + + return outputs + + +def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, sim_range: int): + assert num_frames > 0, "video_length should be greater than 0" + + assert num_frames > cond_frame, "video_length should be greater than cond_frame" + + range_list = RANGE_LIST + + assert sim_range < len(range_list), f"sim_range type{sim_range} not implemented" + + coef = range_list[sim_range] + coef = coef + ([coef[-1]] * (num_frames - len(coef))) + + order = [abs(i - cond_frame) for i in range(num_frames)] + coef = [coef[order[i]] for i in range(num_frames)] + + return coef + + +def _get_freeinit_freq_filter( + shape: Tuple[int, ...], + device: Union[str, torch.dtype], + filter_type: str, + order: float, + spatial_stop_frequency: float, + temporal_stop_frequency: float, +) -> torch.Tensor: + r"""Returns the FreeInit filter based on filter type and other input conditions.""" + + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + + if spatial_stop_frequency == 0 or temporal_stop_frequency == 0: + return mask + + if filter_type == "butterworth": + + def retrieve_mask(x): + return 1 / (1 + (x / spatial_stop_frequency**2) ** order) + elif filter_type == "gaussian": + + def retrieve_mask(x): + return math.exp(-1 / (2 * spatial_stop_frequency**2) * x) + elif filter_type == "ideal": + + def retrieve_mask(x): + return 1 if x <= spatial_stop_frequency * 2 else 0 + else: + raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal") + + for t in range(T): + for h in range(H): + for w in range(W): + d_square = ( + ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / T - 1)) ** 2 + + (2 * h / H - 1) ** 2 + + (2 * w / W - 1) ** 2 + ) + mask[..., t, h, w] = retrieve_mask(d_square) + + return mask.to(device) + + +def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> torch.Tensor: + r"""Noise reinitialization.""" + # FFT + x_freq = fft.fftn(x, dim=(-3, -2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) + noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) + noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) + + # frequency mix + HPF = 1 - LPF + x_freq_low = x_freq * LPF + noise_freq_high = noise_freq * HPF + x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain + + # IFFT + x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) + x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real + + return x_mixed + + +@dataclass +class PIAPipelineOutput(BaseOutput): + frames: Union[torch.Tensor, np.ndarray] + + +class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). 
+ + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + motion_adapter: MotionAdapter, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). 
+ prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. 
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = 
image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = ( + image[None, :] + .reshape( + ( + batch_size, + num_frames, + -1, + ) + + image.shape[2:] + ) + .permute(0, 2, 1, 3, 4) + ) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. 
+ + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + @property + def free_init_enabled(self): + return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None + + def enable_free_init( + self, + num_iters: int = 3, + use_fast_sampling: bool = False, + method: str = "butterworth", + order: int = 4, + spatial_stop_frequency: float = 0.25, + temporal_stop_frequency: float = 0.25, + generator: torch.Generator = None, + ): + """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537. + + This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit). + + Args: + num_iters (`int`, *optional*, defaults to `3`): + Number of FreeInit noise re-initialization iterations. + use_fast_sampling (`bool`, *optional*, defaults to `False`): + Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables + the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`. + method (`str`, *optional*, defaults to `butterworth`): + Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the + FreeInit low pass filter. + order (`int`, *optional*, defaults to `4`): + Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour + whereas lower values lead to `gaussian` method behaviour. + spatial_stop_frequency (`float`, *optional*, defaults to `0.25`): + Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in + the original implementation. + temporal_stop_frequency (`float`, *optional*, defaults to `0.25`): + Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in + the original implementation. + generator (`torch.Generator`, *optional*, defaults to `0.25`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + FreeInit generation deterministic. 
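+
+        Example:
+
+        ```py
+        # assuming `pipe` is a pipeline that exposes `enable_free_init` (e.g. PIAPipeline)
+        pipe.enable_free_init(num_iters=3, method="butterworth", order=4)
+        ```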
+ """ + self._free_init_num_iters = num_iters + self._free_init_use_fast_sampling = use_fast_sampling + self._free_init_method = method + self._free_init_order = order + self._free_init_spatial_stop_frequency = spatial_stop_frequency + self._free_init_temporal_stop_frequency = temporal_stop_frequency + self._free_init_generator = generator + + def disable_free_init(self): + """Disables the FreeInit mechanism if enabled.""" + self._free_init_num_iters = None + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + num_frames, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + _, _, _, scaled_height, scaled_width = shape + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + latents = latents * self.scheduler.init_noise_sigma + + return latents + + def prepare_masked_condition( + self, + image, + batch_size, + num_channels_latents, + num_frames, + height, + width, + dtype, + device, + generator, + latent_timestep, + strength=1.0, + mask_sim_template_idx=0, + cond_frame=0, + ): + + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + _, _, _, scaled_height, scaled_width = shape + + image = self.image_processor.preprocess(image) + image = image.to(device, dtype) + + if isinstance(generator, list): + image_latent = [ + self.vae.encode(image[k : k + 1]).latent_dist.sample(generator[k]) for k in range(batch_size) + ] + image_latent = torch.cat(image_latent, dim=0) + else: + image_latent = self.vae.encode(image).latent_dist.sample(generator) + + image_latent = image_latent.to(device=device, dtype=dtype) + image_latent = torch.nn.functional.interpolate(image_latent, size=[scaled_height, scaled_width]) + image_latent_padding = image_latent.clone() * 0.18215 + + mask = torch.zeros((batch_size, 1, num_frames, scaled_height, scaled_width)).to(device=device, dtype=dtype) + mask_coef = prepare_mask_coef_by_statistics(num_frames, cond_frame, mask_sim_template_idx) + masked_image = torch.zeros(batch_size, 4, num_frames, scaled_height, scaled_width).to( + device=device, dtype=self.unet.dtype + ) + for f in range(num_frames): + mask[:, :, f, :, :] = mask_coef[f] + masked_image[:, :, f, :, :] = image_latent_padding.clone() + + mask = torch.cat([mask] * 2) if self.do_classifier_free_guidance else mask + masked_image = torch.cat([masked_image] * 2) if self.do_classifier_free_guidance else masked_image + + """ + if strength < 1.0: + noise = torch.randn_like(latents) + latents = self.scheduler.add_noise(masked_image[0], noise, latent_timestep) + """ + #latents = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + return mask, masked_image + + def _denoise_loop( + self, + timesteps, + num_inference_steps, + do_classifier_free_guidance, + guidance_scale, + num_warmup_steps, + prompt_embeds, + negative_prompt_embeds, + latents, + mask, + masked_image, + cross_attention_kwargs, + added_cond_kwargs, + extra_step_kwargs, + callback, + callback_steps, + callback_on_step_end, + callback_on_step_end_tensor_inputs, + ): + """Denoising loop for 
AnimateDiff.""" + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, mask, masked_image], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + return latents + + def _free_init_loop( + self, + height, + width, + num_frames, + num_channels_latents, + batch_size, + num_videos_per_prompt, + denoise_args, + device, + ): + """Denoising loop for AnimateDiff using FreeInit noise reinitialization technique.""" + + latents = denoise_args.get("latents") + prompt_embeds = denoise_args.get("prompt_embeds") + timesteps = denoise_args.get("timesteps") + num_inference_steps = denoise_args.get("num_inference_steps") + + latent_shape = ( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + free_init_filter_shape = ( + 1, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + free_init_freq_filter = _get_freeinit_freq_filter( + shape=free_init_filter_shape, + device=device, + filter_type=self._free_init_method, + order=self._free_init_order, + spatial_stop_frequency=self._free_init_spatial_stop_frequency, + temporal_stop_frequency=self._free_init_temporal_stop_frequency, + ) + + with self.progress_bar(total=self._free_init_num_iters) as free_init_progress_bar: + for i in range(self._free_init_num_iters): + # For the first FreeInit iteration, the original latent is used without modification. + # Subsequent iterations apply the noise reinitialization technique. 
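+                # Concretely (see `_freq_mix_3d` above): the partially denoised latents are re-noised
+                # back to the t = T-1 level, low-pass filtered in the frequency domain, and combined
+                # with the high-frequency component of freshly sampled noise before denoising again.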
+ if i == 0: + initial_noise = latents.detach().clone() + else: + current_diffuse_timestep = ( + self.scheduler.config.num_train_timesteps - 1 + ) # diffuse to t=999 noise level + diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long() + z_T = self.scheduler.add_noise( + original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device) + ).to(dtype=torch.float32) + z_rand = randn_tensor( + shape=latent_shape, + generator=self._free_init_generator, + device=device, + dtype=torch.float32, + ) + latents = _freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter) + latents = latents.to(prompt_embeds.dtype) + + # Coarse-to-Fine Sampling for faster inference (can lead to lower quality) + if self._free_init_use_fast_sampling: + current_num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (i + 1)) + self.scheduler.set_timesteps(current_num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + denoise_args.update({"timesteps": timesteps, "num_inference_steps": current_num_inference_steps}) + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + denoise_args.update({"latents": latents, "num_warmup_steps": num_warmup_steps}) + latents = self._denoise_loop(**denoise_args) + + free_init_progress_bar.update() + + return latents + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def _retrieve_video_frames(self, latents, output_type, return_dict): + """Helper function to handle latents to output conversion.""" + if output_type == "latent": + return PIAPipelineOutput(frames=latents) + + video_tensor = self.decode_latents(latents) + video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + + if not return_dict: + return (video,) + + return PIAPipelineOutput(frames=video) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
+ @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: Union[str, List[str]] = None, + strength: float = 1.0, + num_frames: Optional[int] = 16, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + mask_sim_template_idx: int = 0, + cond_frame: int = 0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. 
Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or + `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + + Examples: + + Returns: + [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_videos_per_prompt, output_hidden_state + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + 4, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents=latents + ) + mask, masked_image = self.prepare_masked_condition( + image, + batch_size * num_videos_per_prompt, + 4, + num_frames=num_frames, + height=height, + width=width, + dtype=self.unet.dtype, + device=device, + generator=generator, + latent_timestep=latent_timestep, + strength=strength, + mask_sim_template_idx=mask_sim_template_idx, + cond_frame=cond_frame, + ) + if strength < 1.0: + noise = torch.randn_like(latents) + latents = self.scheduler.add_noise(masked_image[0], noise, latent_timestep) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + denoise_args = { + "timesteps": timesteps, + "num_inference_steps": num_inference_steps, + "do_classifier_free_guidance": self.do_classifier_free_guidance, + "guidance_scale": guidance_scale, + "num_warmup_steps": num_warmup_steps, + "prompt_embeds": prompt_embeds, + "negative_prompt_embeds": negative_prompt_embeds, + "latents": latents, + "mask": mask, + "masked_image": masked_image, + "cross_attention_kwargs": self.cross_attention_kwargs, + "added_cond_kwargs": added_cond_kwargs, + "extra_step_kwargs": extra_step_kwargs, + "callback": callback, + "callback_steps": callback_steps, + "callback_on_step_end": callback_on_step_end, + "callback_on_step_end_tensor_inputs": callback_on_step_end_tensor_inputs, + } + + if self.free_init_enabled: + latents = self._free_init_loop( + height=height, + width=width, + num_frames=num_frames, + num_channels_latents=num_channels_latents, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + denoise_args=denoise_args, + device=device, + ) + else: + latents = self._denoise_loop(**denoise_args) + + video = self._retrieve_video_frames(latents, output_type, return_dict) + + # 9. Offload all models + self.maybe_free_model_hooks() + + return video From 5025fc35943f051aeeb3accb8733146c1b4197f4 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 24 Jan 2024 15:16:20 +0000 Subject: [PATCH 04/18] add tests and docs --- docs/source/en/api/pipelines/pia.md | 191 +++++++++++++++++ tests/pipelines/pia/__init__.py | 0 tests/pipelines/pia/test_pia.py | 314 ++++++++++++++++++++++++++++ 3 files changed, 505 insertions(+) create mode 100644 docs/source/en/api/pipelines/pia.md create mode 100644 tests/pipelines/pia/__init__.py create mode 100644 tests/pipelines/pia/test_pia.py diff --git a/docs/source/en/api/pipelines/pia.md b/docs/source/en/api/pipelines/pia.md new file mode 100644 index 000000000000..c30296c74744 --- /dev/null +++ b/docs/source/en/api/pipelines/pia.md @@ -0,0 +1,191 @@ + + +# Image-to-Video Generation with PIA (Personalized Image Animator) + +## Overview + +Recent advancements in personalized text-to-image (T2I) models have revolutionized content creation, empowering non-experts to generate stunning images with unique styles. While promising, adding realistic motions into these personalized images by text poses significant challenges in preserving distinct styles, high-fidelity details, and achieving motion controllability by text. In this paper, we present PIA, a Personalized Image Animator that excels in aligning with condition images, achieving motion controllability by text, and the compatibility with various personalized T2I models without specific tuning. To achieve these goals, PIA builds upon a base T2I model with well-trained temporal alignment layers, allowing for the seamless transformation of any personalized T2I model into an image animation model. A key component of PIA is the introduction of the condition module, which utilizes the condition frame and inter-frame affinity as input to transfer appearance information guided by the affinity hint for individual frame synthesis in the latent space. This design mitigates the challenges of appearance-related image alignment within and allows for a stronger focus on aligning with motion-related guidance. 
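+
+Concretely, the condition module supplies the motion UNet with a 9-channel input per frame: the 4 noisy latent channels, a 1-channel inter-frame affinity mask, and the 4 VAE-latent channels of the condition image. The short sketch below illustrates how these pieces fit together; the shapes and the affinity preset are chosen for illustration only and this is not part of the pipeline's public API.
+
+```python
+import torch
+
+batch, frames, height, width = 1, 16, 64, 64            # latent-space sizes (illustrative)
+latents = torch.randn(batch, 4, frames, height, width)  # noisy video latents
+image_latent = torch.randn(batch, 4, height, width)     # VAE-encoded condition image
+
+# Inter-frame affinity: frames closer to the condition frame keep more of its appearance.
+coef = [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75]            # a "moderate motion" preset
+coef = coef + [coef[-1]] * (frames - len(coef))
+
+mask = torch.zeros(batch, 1, frames, height, width)
+masked_image = torch.zeros(batch, 4, frames, height, width)
+for f in range(frames):
+    mask[:, :, f] = coef[f]
+    masked_image[:, :, f] = image_latent
+
+# 4 (latents) + 1 (mask) + 4 (condition latents) = 9 channels, matching PIA's expanded conv_in
+unet_input = torch.cat([latents, mask, masked_image], dim=1)
+print(unet_input.shape)  # torch.Size([1, 9, 16, 64, 64])
+```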
+ +## Available Pipelines + +| Pipeline | Tasks | Demo +|---|---|:---:| +| [PIAPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pia/pipeline_pia.py) | *Image-to-Video Generation with PIA* | + +## Available checkpoints + +Motion Adapter checkpoints for PIA can be found under [](). These checkpoints are meant to work with any model based on Stable Diffusion 1.5 + +## Usage example + +PIA works with a MotionAdapter checkpoint and a Stable Diffusion 1.5 model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the Resnet and Attention blocks in the Stable Diffusion UNet. In addition to the motion modules, PIA also replaces the input convolution layer of the SD 1.5 UNet model with a 9 channel input convolution layer. + +The following example demonstrates how to use PIA to generate a video from a single image. + +```python +import torch +from diffusers import ( + DDIMScheduler, + MotionAdapter, + PIAPipeline, +) +from diffusers.utils import export_to_gif, load_image + +adapter = MotionAdapter.from_pretrained("diffusers/PIA-motion-adapter") +pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter) +pipe.scheduler = DDIMScheduler( + clip_sample=False, + steps_offset=1, + timestep_spacing="leading", + beta_schedule="linear", + beta_start=0.00085, + beta_end=0.012, +) +pipe.enable_model_cpu_offload() +pipe.enable_vae_slicing() + +image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true" +) +image = image.resize((512, 512)) +prompt = "cat in a field" +negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality" + +generator = torch.Generator("cpu").manual_seed(0) +output = pipe( + image=image, + prompt=prompt, + strength=0.75, + guidance_scale=7.5, + num_inference_steps=25, + motion_scale=2, + generator=generator, +) +frames = output.frames[0] +export_to_gif(frames, "pia-animation.gif") +``` + +Here are some sample outputs: + + + + + +
+    masterpiece, bestquality, sunset. (sample output GIF omitted)
+    cat in a field (sample output GIF omitted)
+ + + + +If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples. Additionally, the PIA checkpoints can be sensitive to the beta schedule of the scheduler. We recommend setting this to `linear`. + + + +## Using FreeInit + +[FreeInit: Bridging Initialization Gap in Video Diffusion Models](https://arxiv.org/abs/2312.07537) by Tianxing Wu, Chenyang Si, Yuming Jiang, Ziqi Huang, Ziwei Liu. + +FreeInit is an effective method that improves temporal consistency and overall quality of videos generated using video-diffusion-models without any addition training. It can be applied to PIA, AnimateDiff, ModelScope, VideoCrafter and various other video generation models seamlessly at inference time, and works by iteratively refining the latent-initialization noise. More details can be found it the paper. + +The following example demonstrates the usage of FreeInit. + +```python +import torch +from diffusers import ( + DDIMScheduler, + MotionAdapter, + PIAPipeline, +) +from diffusers.utils import export_to_gif, load_image + +adapter = MotionAdapter.from_pretrained("../checkpoints/pia-diffusers") +pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter) + +# enable FreeInit +# Refer to the enable_free_init documentation for a full list of configurable parameters +pipe.enable_free_init(method="butterworth", use_fast_sampling=True) + +# Memory saving options +pipe.enable_model_cpu_offload() +pipe.enable_vae_slicing() + +pipe.scheduler = DDIMScheduler( + clip_sample=False, + steps_offset=1, + timestep_spacing="leading", + beta_schedule="linear", + beta_start=0.00085, + beta_end=0.012, +) +image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true" +) +image = image.resize((512, 512)) +prompt = "cat in a hat" +negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality" + +generator = torch.Generator("cpu").manual_seed(0) + +output = pipe( + image=image, + prompt=prompt, + guidance_scale=7.5, + num_inference_steps=25, + motion_scale=0, + generator=generator, +) +frames = output.frames[0] +export_to_gif(frames, "pia-freeinit-animation.gif") +``` + + + + + +
+ masterpiece, bestquality, sunset. +
+ cat in a hat +
+ + + + +FreeInit is not really free - the improved quality comes at the cost of extra computation. It requires sampling a few extra times depending on the `num_iters` parameter that is set when enabling it. Setting the `use_fast_sampling` parameter to `True` can improve the overall performance (at the cost of lower quality compared to when `use_fast_sampling=False` but still better results than vanilla video generation models). + + + +## PIAPipeline + +[[autodoc]] PIAPipeline + - all + - __call__ + - enable_freeu + - disable_freeu + - enable_free_init + - disable_free_init + - enable_vae_slicing + - disable_vae_slicing + - enable_vae_tiling + - disable_vae_tiling + +## PIAPipelineOutput + +[[autodoc]] pipelines.pia.PIAPipelineOutput \ No newline at end of file diff --git a/tests/pipelines/pia/__init__.py b/tests/pipelines/pia/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py new file mode 100644 index 000000000000..3107f6a39d72 --- /dev/null +++ b/tests/pipelines/pia/test_pia.py @@ -0,0 +1,314 @@ +import random +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +import diffusers +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + MotionAdapter, + PIAPipeline, + UNet2DConditionModel, + UNetMotionModel, +) +from diffusers.utils import is_xformers_available, logging +from diffusers.utils.testing_utils import floats_tensor, torch_device + +from ..test_pipelines_common import PipelineTesterMixin + + +def to_np(tensor): + if isinstance(tensor, torch.Tensor): + tensor = tensor.detach().cpu().numpy() + + return tensor + + +class PIAPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = PIAPipeline + params = frozenset( + [ + "prompt", + "height", + "width", + "guidance_scale", + "negative_prompt", + "prompt_embeds", + "negative_prompt_embeds", + "cross_attention_kwargs", + ] + ) + batch_params = frozenset(["prompt", "image", "generator"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + norm_num_groups=2, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="linear", + clip_sample=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + motion_adapter = MotionAdapter( + block_out_channels=(32, 64), + motion_layers_per_block=2, + motion_norm_num_groups=2, + motion_num_attention_heads=4, + 
use_input_conv=True, + input_conv_channels=9, + ) + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "motion_adapter": motion_adapter, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "feature_extractor": None, + "image_encoder": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + inputs = { + "image": image, + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 7.5, + "output_type": "pt", + } + return inputs + + def test_motion_unet_loading(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + + assert isinstance(pipe.unet, UNetMotionModel) + + @unittest.skip("Attention slicing is not enabled in this pipeline") + def test_attention_slicing_forward_pass(self): + pass + + def test_inference_batch_single_identical( + self, + batch_size=2, + expected_max_diff=1e-4, + additional_params_copy_to_batched_inputs=["num_inference_steps"], + ): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for components in pipe.components.values(): + if hasattr(components, "set_default_attn_processor"): + components.set_default_attn_processor() + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + inputs = self.get_dummy_inputs(torch_device) + # Reset generator in case it is has been used in self.get_dummy_inputs + inputs["generator"] = self.get_generator(0) + + logger = logging.get_logger(pipe.__module__) + logger.setLevel(level=diffusers.logging.FATAL) + + # batchify inputs + batched_inputs = {} + batched_inputs.update(inputs) + + for name in self.batch_params: + if name not in inputs: + continue + + value = inputs[name] + if name == "prompt": + len_prompt = len(value) + batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)] + batched_inputs[name][-1] = 100 * "very long" + + else: + batched_inputs[name] = batch_size * [value] + + if "generator" in inputs: + batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)] + + if "batch_size" in inputs: + batched_inputs["batch_size"] = batch_size + + for arg in additional_params_copy_to_batched_inputs: + batched_inputs[arg] = inputs[arg] + + output = pipe(**inputs) + output_batch = pipe(**batched_inputs) + + assert output_batch[0].shape[0] == batch_size + + max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max() + assert max_diff < expected_max_diff + + @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + def test_to_device(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + pipe.to("cpu") + # pipeline creates a new motion UNet under the hood. 
So we need to check the device from pipe.components + model_devices = [ + component.device.type for component in pipe.components.values() if hasattr(component, "device") + ] + self.assertTrue(all(device == "cpu" for device in model_devices)) + + output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] + self.assertTrue(np.isnan(output_cpu).sum() == 0) + + pipe.to("cuda") + model_devices = [ + component.device.type for component in pipe.components.values() if hasattr(component, "device") + ] + self.assertTrue(all(device == "cuda" for device in model_devices)) + + output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0] + self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) + + def test_to_dtype(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + # pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) + + pipe.to(torch_dtype=torch.float16) + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) + + def test_prompt_embeds(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + inputs.pop("prompt") + inputs["prompt_embeds"] = torch.randn((1, 4, 32), device=torch_device) + pipe(**inputs) + + def test_free_init(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs_normal = self.get_dummy_inputs(torch_device) + frames_normal = pipe(**inputs_normal).frames[0] + + free_init_generator = torch.Generator(device=torch_device).manual_seed(0) + pipe.enable_free_init( + num_iters=2, + use_fast_sampling=True, + method="butterworth", + order=4, + spatial_stop_frequency=0.25, + temporal_stop_frequency=0.25, + generator=free_init_generator, + ) + inputs_enable_free_init = self.get_dummy_inputs(torch_device) + frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0] + + pipe.disable_free_init() + inputs_disable_free_init = self.get_dummy_inputs(torch_device) + frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0] + + sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum() + max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max() + self.assertGreater( + sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results" + ) + self.assertLess( + max_diff_disabled, + 1e-4, + "Disabling of FreeInit should lead to results similar to the default pipeline results", + ) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + 
inputs = self.get_dummy_inputs(torch_device) + output_without_offload = pipe(**inputs).frames[0] + output_without_offload = ( + output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload + ) + + pipe.enable_xformers_memory_efficient_attention() + inputs = self.get_dummy_inputs(torch_device) + output_with_offload = pipe(**inputs).frames[0] + output_with_offload = ( + output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload + ) + + max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max() + self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results") From 845d7f722d0a2a7fac93f7afc3475f09a893f94f Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 24 Jan 2024 15:24:21 +0000 Subject: [PATCH 05/18] clean up --- src/diffusers/__init__.py | 2 + .../models/unets/unet_motion_model.py | 10 +- src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/pia/__init__.py | 46 +++++++++ src/diffusers/pipelines/pia/pipeline_pia.py | 98 +++++++++++-------- 5 files changed, 116 insertions(+), 42 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 107547abf2f5..7deeeb6f89a4 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -247,6 +247,7 @@ "LDMTextToImagePipeline", "MusicLDMPipeline", "PaintByExamplePipeline", + "PIAPipeline", "PixArtAlphaPipeline", "SemanticStableDiffusionPipeline", "ShapEImg2ImgPipeline", @@ -606,6 +607,7 @@ LDMTextToImagePipeline, MusicLDMPipeline, PaintByExamplePipeline, + PIAPipeline, PixArtAlphaPipeline, SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index dd0ac0f89fce..be2bf4dcdc95 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -426,16 +426,24 @@ def from_unet2d( if not config.get("num_attention_heads"): config["num_attention_heads"] = config["attention_head_dim"] + if motion_adapter.config["use_input_conv"]: + config["in_channels"] = motion_adapter.config["input_conv_channels"] + model = cls.from_config(config) if not load_weights: return model - if motion_adapter.config["use_input_conv"]: + # The check for the unet.conv_in weight shape is meant to handle cases where + # `save_pretrained` was used with the PIAPipeline, resulting in the PIA weights + # being fused UNet conv_in weights. 
In such a case we would load the fused weights + # directly into model.conv_in + if motion_adapter.config["use_input_conv"] and unet.conv_in.weight.shape[1] == 4: model.conv_in = motion_adapter.conv_in updated_conv_in_weight = torch.cat( [unet.conv_in.weight, motion_adapter.conv_in.weight[:, 4:, :, :]], dim=1 ) + model.conv_in.load_state_dict({"weight": updated_conv_in_weight, "bias": unet.conv_in.bias}) else: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 13163f96e059..f2304ae7a914 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -168,6 +168,7 @@ _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"]) _import_structure["musicldm"] = ["MusicLDMPipeline"] _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] + _import_structure["pia"] = ["PIAPipeline"] _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline"] _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] @@ -412,6 +413,7 @@ from .latent_diffusion import LDMTextToImagePipeline from .musicldm import MusicLDMPipeline from .paint_by_example import PaintByExamplePipeline + from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline diff --git a/src/diffusers/pipelines/pia/__init__.py b/src/diffusers/pipelines/pia/__init__.py index e69de29bb2d1..16e8004966e5 100644 --- a/src/diffusers/pipelines/pia/__init__.py +++ b/src/diffusers/pipelines/pia/__init__.py @@ -0,0 +1,46 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_pia"] = ["PIAPipeline", "PIAPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .pipeline_pia import PIAPipeline, PIAPipelineOutput + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 0caedd8e67be..f5b39edbb58e 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -54,15 +54,39 @@ Examples: ```py >>> import torch - >>> from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler - >>> from diffusers.utils import export_to_gif - - >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2") - >>> pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter) - >>> 
pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False) - >>> output = pipe(prompt="A corgi walking in the park") + >>> from diffusers import ( + ... DDIMScheduler, + ... MotionAdapter, + ... PIAPipeline, + ... ) + >>> from diffusers.utils import export_to_gif, load_image + >>> adapter = MotionAdapter.from_pretrained("../checkpoints/pia-diffusers") + >>> pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter) + >>> pipe.scheduler = DDIMScheduler( + ... clip_sample=False, + ... steps_offset=1, + ... timestep_spacing="leading", + ... beta_schedule="linear", + ... beta_start=0.00085, + ... beta_end=0.012, + ... ) + >>> image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true" + ... ) + >>> image = image.resize((512, 512)) + >>> prompt = "cat in a hat" + >>> negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" + >>> generator = torch.Generator("cpu").manual_seed(0) + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... guidance_scale=7.5, + ... num_inference_steps=25, + ... motion_scale=0, + ... generator=generator, + ... ) >>> frames = output.frames[0] - >>> export_to_gif(frames, "animation.gif") + >>> export_to_gif(frames, "pia-animation.gif") ``` """ @@ -73,12 +97,13 @@ [1.0, 0.9, 0.85, 0.85, 0.85, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.85, 0.85, 0.9, 1.0], # Loop [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75, 0.75, 0.75, 0.75, 0.75, 0.78, 0.79, 0.8, 0.8, 1.0], # Loop [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5, 0.6, 0.7, 0.7, 0.7, 0.7, 0.8, 1.0], # Loop - [0.5, 0.2], # Style Transfer Large Motion - [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2], # Style Transfer Moderate Motion [0.5, 0.4, 0.4, 0.4, 0.35, 0.3], # Style Transfer Candidate Small Motion + [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2], # Style Transfer Moderate Motion + [0.5, 0.2], # Style Transfer Large Motion ] +# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): batch_size, channels, num_frames, height, width = video.shape outputs = [] @@ -100,16 +125,16 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: return outputs -def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, sim_range: int): +def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_scale: int): assert num_frames > 0, "video_length should be greater than 0" assert num_frames > cond_frame, "video_length should be greater than cond_frame" range_list = RANGE_LIST - assert sim_range < len(range_list), f"sim_range type{sim_range} not implemented" + assert motion_scale < len(range_list), f"motion_scale type{motion_scale} not implemented" - coef = range_list[sim_range] + coef = range_list[motion_scale] coef = coef + ([coef[-1]] * (num_frames - len(coef))) order = [abs(i - cond_frame) for i in range(num_frames)] @@ -717,12 +742,8 @@ def prepare_masked_condition( dtype, device, generator, - latent_timestep, - strength=1.0, - mask_sim_template_idx=0, - cond_frame=0, + motion_scale=0, ): - shape = ( batch_size, 
num_channels_latents, @@ -745,10 +766,10 @@ def prepare_masked_condition( image_latent = image_latent.to(device=device, dtype=dtype) image_latent = torch.nn.functional.interpolate(image_latent, size=[scaled_height, scaled_width]) - image_latent_padding = image_latent.clone() * 0.18215 + image_latent_padding = image_latent.clone() * self.vae.config.scaling_factor mask = torch.zeros((batch_size, 1, num_frames, scaled_height, scaled_width)).to(device=device, dtype=dtype) - mask_coef = prepare_mask_coef_by_statistics(num_frames, cond_frame, mask_sim_template_idx) + mask_coef = prepare_mask_coef_by_statistics(num_frames, 0, motion_scale) masked_image = torch.zeros(batch_size, 4, num_frames, scaled_height, scaled_width).to( device=device, dtype=self.unet.dtype ) @@ -759,13 +780,6 @@ def prepare_masked_condition( mask = torch.cat([mask] * 2) if self.do_classifier_free_guidance else mask masked_image = torch.cat([masked_image] * 2) if self.do_classifier_free_guidance else masked_image - """ - if strength < 1.0: - noise = torch.randn_like(latents) - latents = self.scheduler.add_noise(masked_image[0], noise, latent_timestep) - """ - #latents = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - return mask, masked_image def _denoise_loop( @@ -788,7 +802,7 @@ def _denoise_loop( callback_on_step_end, callback_on_step_end_tensor_inputs, ): - """Denoising loop for AnimateDiff.""" + """Denoising loop for PIA.""" with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance @@ -836,13 +850,12 @@ def _free_init_loop( height, width, num_frames, - num_channels_latents, batch_size, num_videos_per_prompt, denoise_args, device, ): - """Denoising loop for AnimateDiff using FreeInit noise reinitialization technique.""" + """Denoising loop for PIA using FreeInit noise reinitialization technique.""" latents = denoise_args.get("latents") prompt_embeds = denoise_args.get("prompt_embeds") @@ -851,14 +864,14 @@ def _free_init_loop( latent_shape = ( batch_size * num_videos_per_prompt, - num_channels_latents, + 4, num_frames, height // self.vae_scale_factor, width // self.vae_scale_factor, ) free_init_filter_shape = ( 1, - num_channels_latents, + 4, num_frames, height // self.vae_scale_factor, width // self.vae_scale_factor, @@ -975,8 +988,7 @@ def __call__( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, - mask_sim_template_idx: int = 0, - cond_frame: int = 0, + motion_scale: int = 0, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -989,8 +1001,11 @@ def __call__( The call function to the pipeline for generation. Args: + image (`PipelineImageInput`): + The input image to be used for video generation. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + strength (`float`, *optional*, defaults to 1.0): Indicates extent to transform the reference `image`. Must be between 0 and 1. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated video. 
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): @@ -1026,6 +1041,12 @@ def __call__( not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + motion_scale: (`int`, *optional*, defaults to 0): + Parameter that controls the amount and type of motion that is added to the image. Increasing the value increases the amount of motion, while specific + ranges of values control the type of motion that is added. Must be between 0 and 8. + Set between 0-2 to only increase the amount of motion. + Set between 3-5 to create looping motion. + Set between 6-8 to perform motion with image style transfer. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or `np.array`. @@ -1140,7 +1161,6 @@ def __call__( self._num_timesteps = len(timesteps) # 5. Prepare latent variables - num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, 4, @@ -1150,7 +1170,7 @@ def __call__( prompt_embeds.dtype, device, generator, - latents=latents + latents=latents, ) mask, masked_image = self.prepare_masked_condition( image, @@ -1162,10 +1182,7 @@ def __call__( dtype=self.unet.dtype, device=device, generator=generator, - latent_timestep=latent_timestep, - strength=strength, - mask_sim_template_idx=mask_sim_template_idx, - cond_frame=cond_frame, + motion_scale=motion_scale, ) if strength < 1.0: noise = torch.randn_like(latents) @@ -1204,7 +1221,6 @@ def __call__( height=height, width=width, num_frames=num_frames, - num_channels_latents=num_channels_latents, batch_size=batch_size, num_videos_per_prompt=num_videos_per_prompt, denoise_args=denoise_args, From 1ba867ddea62ae9e02a42558387f893f032a852c Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 24 Jan 2024 15:30:40 +0000 Subject: [PATCH 06/18] add to toctree --- docs/source/en/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index acca4b148c58..adf6731de5a2 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -300,6 +300,8 @@ title: MusicLDM - local: api/pipelines/paint_by_example title: Paint by Example + - local: api/pipelines/pia + title: Personalized Image Animator (PIA) - local: api/pipelines/pixart title: PixArt-α - local: api/pipelines/self_attention_guidance From 8d1794db960f76bba784ee503d4676f235a3224e Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 24 Jan 2024 15:32:46 +0000 Subject: [PATCH 07/18] fix copies --- .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 2eb9599658d9..751548d481c9 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -647,6 +647,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class PIAPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + 
def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class PixArtAlphaPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 1c7b31a580d042a07bcacdb9131d050ec09819f1 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 25 Jan 2024 13:12:50 +0000 Subject: [PATCH 08/18] pr review feedback --- .../models/unets/unet_motion_model.py | 27 +++++++------------ src/diffusers/pipelines/pia/pipeline_pia.py | 19 ++++++------- 2 files changed, 18 insertions(+), 28 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index be2bf4dcdc95..5032ec41b71b 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -89,8 +89,7 @@ def __init__( motion_norm_num_groups: int = 32, motion_max_seq_length: int = 32, use_motion_mid_block: bool = True, - use_input_conv: bool = False, - input_conv_channels: int = 9, + conv_in_channels: Optional[int] = None, ): """Container to store AnimateDiff Motion Modules @@ -115,12 +114,12 @@ def __init__( down_blocks = [] up_blocks = [] - if use_input_conv: + if conv_in_channels: # input conv_in_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( - input_conv_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + conv_in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) else: self.conv_in = None @@ -422,30 +421,27 @@ def from_unet2d( config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"] config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"] + # For PIA UNets we need to set the number input channels to 9 + if motion_adapter.config["conv_in_channels"]: + config["in_channels"] = motion_adapter.config["conv_in_channels"] + # Need this for backwards compatibility with UNet2DConditionModel checkpoints if not config.get("num_attention_heads"): config["num_attention_heads"] = config["attention_head_dim"] - if motion_adapter.config["use_input_conv"]: - config["in_channels"] = motion_adapter.config["input_conv_channels"] - model = cls.from_config(config) if not load_weights: return model - # The check for the unet.conv_in weight shape is meant to handle cases where - # `save_pretrained` was used with the PIAPipeline, resulting in the PIA weights - # being fused UNet conv_in weights. In such a case we would load the fused weights - # directly into model.conv_in - if motion_adapter.config["use_input_conv"] and unet.conv_in.weight.shape[1] == 4: + # Logic for loading PIA UNets which allow the first 4 channels to be any UNet2DConditionModel conv_in weight + # while the last 5 channels must be PIA conv_in weights. 
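+        # The 5 extra input channels are expected to carry PIA's image conditioning
+        # (the 1-channel mask and the 4-channel masked-image latent built in
+        # `PIAPipeline.prepare_masked_condition`), so only channels 4: of the adapter's
+        # conv_in weight are kept here, while the first 4 channels and the bias reuse
+        # the base UNet's conv_in.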
+ if has_motion_adapter and motion_adapter.config["conv_in_channels"]: model.conv_in = motion_adapter.conv_in updated_conv_in_weight = torch.cat( [unet.conv_in.weight, motion_adapter.conv_in.weight[:, 4:, :, :]], dim=1 ) - model.conv_in.load_state_dict({"weight": updated_conv_in_weight, "bias": unet.conv_in.bias}) - else: model.conv_in.load_state_dict(unet.conv_in.state_dict()) @@ -517,9 +513,6 @@ def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None: if hasattr(self.mid_block, "motion_modules"): self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict()) - if motion_adapter.config["use_input_conv"]: - self.conv_in.load_state_dict(motion_adapter.conv_in.state_dict()) - def save_motion_modules( self, save_directory: str, diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index f5b39edbb58e..b8bf6a03368c 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -107,6 +107,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): batch_size, channels, num_frames, height, width = video.shape outputs = [] + for batch_idx in range(batch_size): batch_vid = video[batch_idx].permute(1, 0, 2, 3) batch_output = processor.postprocess(batch_vid, output_type) @@ -243,7 +244,7 @@ class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin """ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" - _optional_components = ["feature_extractor", "image_encoder"] + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( @@ -251,8 +252,7 @@ def __init__( vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - motion_adapter: MotionAdapter, + unet: Union[UNet2DConditionModel, UNetMotionModel], scheduler: Union[ DDIMScheduler, PNDMScheduler, @@ -261,11 +261,13 @@ def __init__( EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ], + motion_adapter: Optional[MotionAdapter] = None, feature_extractor: CLIPImageProcessor = None, image_encoder: CLIPVisionModelWithProjection = None, ): super().__init__() - unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) self.register_modules( vae=vae, @@ -695,6 +697,7 @@ def check_inputs( f" {negative_prompt_embeds.shape}." 
) + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents def prepare_latents( self, batch_size, @@ -715,7 +718,6 @@ def prepare_latents( width // self.vae_scale_factor, ) - _, _, _, scaled_height, scaled_width = shape if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -797,8 +799,6 @@ def _denoise_loop( cross_attention_kwargs, added_cond_kwargs, extra_step_kwargs, - callback, - callback_steps, callback_on_step_end, callback_on_step_end_tensor_inputs, ): @@ -840,8 +840,6 @@ def _denoise_loop( # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) return latents @@ -923,6 +921,7 @@ def _free_init_loop( return latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) @@ -1210,8 +1209,6 @@ def __call__( "cross_attention_kwargs": self.cross_attention_kwargs, "added_cond_kwargs": added_cond_kwargs, "extra_step_kwargs": extra_step_kwargs, - "callback": callback, - "callback_steps": callback_steps, "callback_on_step_end": callback_on_step_end, "callback_on_step_end_tensor_inputs": callback_on_step_end_tensor_inputs, } From a93a5ba95cce4079782dc9a2cae47e55362eaf7a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 25 Jan 2024 13:27:07 +0000 Subject: [PATCH 09/18] fix copies --- src/diffusers/pipelines/pia/pipeline_pia.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index b8bf6a03368c..576b1a66fdc0 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -107,7 +107,6 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): batch_size, channels, num_frames, height, width = video.shape outputs = [] - for batch_idx in range(batch_size): batch_vid = video[batch_idx].permute(1, 0, 2, 3) batch_output = processor.postprocess(batch_vid, output_type) @@ -699,16 +698,7 @@ def check_inputs( # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents def prepare_latents( - self, - batch_size, - num_channels_latents, - num_frames, - height, - width, - dtype, - device, - generator, - latents=None, + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): shape = ( batch_size, @@ -717,7 +707,6 @@ def prepare_latents( height // self.vae_scale_factor, width // self.vae_scale_factor, ) - if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -729,8 +718,8 @@ def prepare_latents( else: latents = latents.to(device) + # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma - return latents def prepare_masked_condition( @@ -927,7 
+916,7 @@ def get_timesteps(self, num_inference_steps, strength, device): init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] return timesteps, num_inference_steps - t_start From 4c9244dd3be0c98ae39eebe2b9721ef0cd4d0ea6 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 25 Jan 2024 13:33:25 +0000 Subject: [PATCH 10/18] fix tests --- tests/pipelines/pia/test_pia.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py index 3107f6a39d72..eb76457abc9d 100644 --- a/tests/pipelines/pia/test_pia.py +++ b/tests/pipelines/pia/test_pia.py @@ -100,8 +100,7 @@ def get_dummy_components(self): motion_layers_per_block=2, motion_norm_num_groups=2, motion_num_attention_heads=4, - use_input_conv=True, - input_conv_channels=9, + conv_in_channels=9, ) components = { From b3e4161e57c9f111b2115fbac8c6379cd31c40ed Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 30 Jan 2024 11:57:05 +0000 Subject: [PATCH 11/18] update docs --- docs/source/en/api/pipelines/pia.md | 27 ++++++--------------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/docs/source/en/api/pipelines/pia.md b/docs/source/en/api/pipelines/pia.md index c30296c74744..8b50f8013961 100644 --- a/docs/source/en/api/pipelines/pia.md +++ b/docs/source/en/api/pipelines/pia.md @@ -24,7 +24,7 @@ Recent advancements in personalized text-to-image (T2I) models have revolutioniz ## Available checkpoints -Motion Adapter checkpoints for PIA can be found under [](). These checkpoints are meant to work with any model based on Stable Diffusion 1.5 +Motion Adapter checkpoints for PIA can be found under the [OpenMMLab org](https://huggingface.co/openmmlab/PIA-condition-adapter). 
These checkpoints are meant to work with any model based on Stable Diffusion 1.5 ## Usage example @@ -41,8 +41,8 @@ from diffusers import ( ) from diffusers.utils import export_to_gif, load_image -adapter = MotionAdapter.from_pretrained("diffusers/PIA-motion-adapter") -pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter) +adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter") +pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16) pipe.scheduler = DDIMScheduler( clip_sample=False, steps_offset=1, @@ -62,15 +62,7 @@ prompt = "cat in a field" negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality" generator = torch.Generator("cpu").manual_seed(0) -output = pipe( - image=image, - prompt=prompt, - strength=0.75, - guidance_scale=7.5, - num_inference_steps=25, - motion_scale=2, - generator=generator, -) +output = pipe(image=image, prompt=prompt, generator=generator) frames = output.frames[0] export_to_gif(frames, "pia-animation.gif") ``` @@ -113,7 +105,7 @@ from diffusers import ( ) from diffusers.utils import export_to_gif, load_image -adapter = MotionAdapter.from_pretrained("../checkpoints/pia-diffusers") +adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter") pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter) # enable FreeInit @@ -141,14 +133,7 @@ negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality generator = torch.Generator("cpu").manual_seed(0) -output = pipe( - image=image, - prompt=prompt, - guidance_scale=7.5, - num_inference_steps=25, - motion_scale=0, - generator=generator, -) +output = pipe(image=image, prompt=prompt, generator=generator) frames = output.frames[0] export_to_gif(frames, "pia-freeinit-animation.gif") ``` From e4535a52b2d003f16323dfa3130ef14681ec4d77 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 31 Jan 2024 06:36:33 +0000 Subject: [PATCH 12/18] update --- docs/source/en/api/pipelines/pia.md | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/docs/source/en/api/pipelines/pia.md b/docs/source/en/api/pipelines/pia.md index 8b50f8013961..52ed8d977ba7 100644 --- a/docs/source/en/api/pipelines/pia.md +++ b/docs/source/en/api/pipelines/pia.md @@ -35,7 +35,7 @@ The following example demonstrates how to use PIA to generate a video from a sin ```python import torch from diffusers import ( - DDIMScheduler, + EulerDiscreteScheduler, MotionAdapter, PIAPipeline, ) @@ -43,11 +43,9 @@ from diffusers.utils import export_to_gif, load_image adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter") pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16) -pipe.scheduler = DDIMScheduler( - clip_sample=False, + +pipe.scheduler = EulerDiscreteScheduler( steps_offset=1, - timestep_spacing="leading", - beta_schedule="linear", beta_start=0.00085, beta_end=0.012, ) @@ -99,7 +97,7 @@ The following example demonstrates the usage of FreeInit. 
```python import torch from diffusers import ( - DDIMScheduler, + EulerDiscreteScheduler, MotionAdapter, PIAPipeline, ) @@ -116,11 +114,8 @@ pipe.enable_free_init(method="butterworth", use_fast_sampling=True) pipe.enable_model_cpu_offload() pipe.enable_vae_slicing() -pipe.scheduler = DDIMScheduler( - clip_sample=False, +pipe.scheduler = EulerDiscreteScheduler( steps_offset=1, - timestep_spacing="leading", - beta_schedule="linear", beta_start=0.00085, beta_end=0.012, ) From 8961802a37f75dc4c9d82946ee71819916a0f3e1 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 31 Jan 2024 07:02:53 +0000 Subject: [PATCH 13/18] update --- src/diffusers/pipelines/pia/pipeline_pia.py | 22 ++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 576b1a66fdc0..a9a19fbdda31 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -18,6 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np +import PIL import torch import torch.fft as fft from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection @@ -55,18 +56,15 @@ ```py >>> import torch >>> from diffusers import ( - ... DDIMScheduler, + ... EulerDiscreteScheduler, ... MotionAdapter, ... PIAPipeline, ... ) >>> from diffusers.utils import export_to_gif, load_image >>> adapter = MotionAdapter.from_pretrained("../checkpoints/pia-diffusers") >>> pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter) - >>> pipe.scheduler = DDIMScheduler( - ... clip_sample=False, + >>> pipe.scheduler = EulerDiscreteScheduler( ... steps_offset=1, - ... timestep_spacing="leading", - ... beta_schedule="linear", ... beta_start=0.00085, ... beta_end=0.012, ... ) @@ -210,7 +208,16 @@ def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> tor @dataclass class PIAPipelineOutput(BaseOutput): - frames: Union[torch.Tensor, np.ndarray] + r""" + Output class for PIAPipeline. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[PIL.Image.Image]): + Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`, + NumPy array of shape `(batch_size, num_frames, channels, height, width, + Torch tensor of shape `(batch_size, num_frames, channels, height, width)`. + """ + frames: Union[torch.Tensor, np.ndarray, PIL.Image.Image] class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin): @@ -585,7 +592,7 @@ def enable_free_init( order: int = 4, spatial_stop_frequency: float = 0.25, temporal_stop_frequency: float = 0.25, - generator: torch.Generator = None, + generator: Optional[torch.Generator] = None, ): """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537. 
@@ -1110,6 +1117,7 @@ def __call__( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] + print(batch_size) device = self._execution_device From 7542799344c4a55fa8fc0385f57f1c84e1fd2456 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 31 Jan 2024 08:33:20 +0000 Subject: [PATCH 14/18] update docs --- docs/source/en/api/pipelines/pia.md | 4 ++++ src/diffusers/pipelines/pia/pipeline_pia.py | 1 + 2 files changed, 5 insertions(+) diff --git a/docs/source/en/api/pipelines/pia.md b/docs/source/en/api/pipelines/pia.md index 52ed8d977ba7..a9143eeca3b6 100644 --- a/docs/source/en/api/pipelines/pia.md +++ b/docs/source/en/api/pipelines/pia.md @@ -14,8 +14,12 @@ specific language governing permissions and limitations under the License. ## Overview +[PIA: Your Personalized Image Animator via Plug-and-Play Modules in Text-to-Image Models](https://arxiv.org/abs/2312.13964) by Yiming Zhang, Zhening Xing, Yanhong Zeng, Youqing Fang, Kai Chen + Recent advancements in personalized text-to-image (T2I) models have revolutionized content creation, empowering non-experts to generate stunning images with unique styles. While promising, adding realistic motions into these personalized images by text poses significant challenges in preserving distinct styles, high-fidelity details, and achieving motion controllability by text. In this paper, we present PIA, a Personalized Image Animator that excels in aligning with condition images, achieving motion controllability by text, and the compatibility with various personalized T2I models without specific tuning. To achieve these goals, PIA builds upon a base T2I model with well-trained temporal alignment layers, allowing for the seamless transformation of any personalized T2I model into an image animation model. A key component of PIA is the introduction of the condition module, which utilizes the condition frame and inter-frame affinity as input to transfer appearance information guided by the affinity hint for individual frame synthesis in the latent space. This design mitigates the challenges of appearance-related image alignment within and allows for a stronger focus on aligning with motion-related guidance. +[Project page](https://pi-animator.github.io/) + ## Available Pipelines | Pipeline | Tasks | Demo diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index a9a19fbdda31..7a14d0cc75c5 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -217,6 +217,7 @@ class PIAPipelineOutput(BaseOutput): NumPy array of shape `(batch_size, num_frames, channels, height, width, Torch tensor of shape `(batch_size, num_frames, channels, height, width)`. 
""" + frames: Union[torch.Tensor, np.ndarray, PIL.Image.Image] From 911b6825c9f0558f9f74933822499ab3bac407c4 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 31 Jan 2024 13:50:55 +0000 Subject: [PATCH 15/18] update --- docs/source/en/api/pipelines/pia.md | 6 +--- src/diffusers/pipelines/pia/pipeline_pia.py | 35 ++------------------- 2 files changed, 3 insertions(+), 38 deletions(-) diff --git a/docs/source/en/api/pipelines/pia.md b/docs/source/en/api/pipelines/pia.md index a9143eeca3b6..3097aac22e19 100644 --- a/docs/source/en/api/pipelines/pia.md +++ b/docs/source/en/api/pipelines/pia.md @@ -48,11 +48,7 @@ from diffusers.utils import export_to_gif, load_image adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter") pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16) -pipe.scheduler = EulerDiscreteScheduler( - steps_offset=1, - beta_start=0.00085, - beta_end=0.012, -) +pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) pipe.enable_model_cpu_offload() pipe.enable_vae_slicing() diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 7a14d0cc75c5..3b2fbe159627 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -39,7 +39,6 @@ from ...utils import ( USE_PEFT_BACKEND, BaseOutput, - deprecate, logging, replace_example_docstring, scale_lora_layers, @@ -63,11 +62,7 @@ >>> from diffusers.utils import export_to_gif, load_image >>> adapter = MotionAdapter.from_pretrained("../checkpoints/pia-diffusers") >>> pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter) - >>> pipe.scheduler = EulerDiscreteScheduler( - ... steps_offset=1, - ... beta_start=0.00085, - ... beta_end=0.012, - ... ) + >>> pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) >>> image = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true" ... ) @@ -657,7 +652,6 @@ def check_inputs( prompt, height, width, - callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, @@ -666,11 +660,6 @@ def check_inputs( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -991,7 +980,6 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - **kwargs, ): r""" The call function to the pipeline for generation. @@ -1072,23 +1060,6 @@ def __call__( If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. 
""" - - callback = kwargs.pop("callback", None) - callback_steps = kwargs.pop("callback_steps", None) - - if callback is not None: - deprecate( - "callback", - "1.0.0", - "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", - ) - if callback_steps is not None: - deprecate( - "callback_steps", - "1.0.0", - "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", - ) - # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor @@ -1100,7 +1071,6 @@ def __call__( prompt, height, width, - callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, @@ -1118,7 +1088,6 @@ def __call__( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - print(batch_size) device = self._execution_device @@ -1182,7 +1151,7 @@ def __call__( motion_scale=motion_scale, ) if strength < 1.0: - noise = torch.randn_like(latents) + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) latents = self.scheduler.add_noise(masked_image[0], noise, latent_timestep) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline From 2aba6316cd0fa8d71f9ec62550f3e47f51a34eb9 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 31 Jan 2024 14:14:58 +0000 Subject: [PATCH 16/18] update --- src/diffusers/pipelines/pia/pipeline_pia.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 3b2fbe159627..3bf2dbfe2b19 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -146,7 +146,7 @@ def _get_freeinit_freq_filter( ) -> torch.Tensor: r"""Returns the FreeInit filter based on filter type and other input conditions.""" - T, H, W = shape[-3], shape[-2], shape[-1] + time, height, width = shape[-3], shape[-2], shape[-1] mask = torch.zeros(shape) if spatial_stop_frequency == 0 or temporal_stop_frequency == 0: @@ -167,13 +167,13 @@ def retrieve_mask(x): else: raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal") - for t in range(T): - for h in range(H): - for w in range(W): + for t in range(time): + for h in range(height): + for w in range(width): d_square = ( - ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / T - 1)) ** 2 - + (2 * h / H - 1) ** 2 - + (2 * w / W - 1) ** 2 + ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2 + + (2 * h / height - 1) ** 2 + + (2 * w / width - 1) ** 2 ) mask[..., t, h, w] = retrieve_mask(d_square) @@ -646,7 +646,6 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, From d01a61bc81d04d4e7ed0c1b09e5701fbe851b8dd Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 31 Jan 2024 14:28:08 +0000 Subject: [PATCH 17/18] update --- docs/source/en/api/pipelines/pia.md | 8 ++------ src/diffusers/models/unets/unet_motion_model.py | 6 +----- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/docs/source/en/api/pipelines/pia.md b/docs/source/en/api/pipelines/pia.md index 3097aac22e19..20ca41a8e44e 100644 --- 
a/docs/source/en/api/pipelines/pia.md +++ b/docs/source/en/api/pipelines/pia.md @@ -97,7 +97,7 @@ The following example demonstrates the usage of FreeInit. ```python import torch from diffusers import ( - EulerDiscreteScheduler, + DDIMScheduler, MotionAdapter, PIAPipeline, ) @@ -114,11 +114,7 @@ pipe.enable_free_init(method="butterworth", use_fast_sampling=True) pipe.enable_model_cpu_offload() pipe.enable_vae_slicing() -pipe.scheduler = EulerDiscreteScheduler( - steps_offset=1, - beta_start=0.00085, - beta_end=0.012, -) +pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) image = load_image( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true" ) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 5032ec41b71b..828510794ef3 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -116,11 +116,7 @@ def __init__( if conv_in_channels: # input - conv_in_kernel = 3 - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2d( - conv_in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding - ) + self.conv_in = nn.Conv2d(conv_in_channels, block_out_channels[0], kernel_size=3, padding=1) else: self.conv_in = None From facdd91be9b61f944797628e598d12c03b814c2a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 31 Jan 2024 14:29:55 +0000 Subject: [PATCH 18/18] update --- src/diffusers/pipelines/pia/pipeline_pia.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 3bf2dbfe2b19..565544a0fef4 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -70,14 +70,7 @@ >>> prompt = "cat in a hat" >>> negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" >>> generator = torch.Generator("cpu").manual_seed(0) - >>> output = pipe( - ... image=image, - ... prompt=prompt, - ... guidance_scale=7.5, - ... num_inference_steps=25, - ... motion_scale=0, - ... generator=generator, - ... ) + >>> output = pipe(image=image, prompt=prompt, negative_prompt=negative_prompt, generator=generator) >>> frames = output.frames[0] >>> export_to_gif(frames, "pia-animation.gif") ```
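For reference, the conditioning mechanism added by these patches boils down to a 9-channel UNet input: 4 channels of noisy video latents, a 1-channel per-frame condition mask, and the 4-channel latent of the reference image, with per-frame weights chosen from `RANGE_LIST` via `motion_scale`. The sketch below only illustrates that layout under stated assumptions — the tensor names, the frame count, the use of `RANGE_LIST[0]`, the per-frame weighting, and the exact concatenation order are illustrative assumptions, not code taken from `pipeline_pia.py`.

```python
import torch

# Illustrative shapes: latent-space H/W for a 512x512 image, 16 frames.
batch_size, num_frames, height, width = 1, 16, 64, 64
latents = torch.randn(batch_size, 4, num_frames, height, width)   # noisy video latents
image_latent = torch.randn(batch_size, 4, height, width)          # stand-in for the encoded reference image

# RANGE_LIST[0] ("small motion"); motion_scale indexes into RANGE_LIST.
coef = [1.0, 0.9, 0.85, 0.85, 0.85, 0.8]
coef = coef + [coef[-1]] * (num_frames - len(coef))                # pad to num_frames
cond_frame = 0
order = [abs(i - cond_frame) for i in range(num_frames)]           # distance from the condition frame

mask = torch.zeros(batch_size, 1, num_frames, height, width)
masked_image = torch.zeros(batch_size, 4, num_frames, height, width)
for f in range(num_frames):
    mask[:, :, f] = coef[order[f]]                                 # frames far from the condition frame get smaller weights
    masked_image[:, :, f] = coef[order[f]] * image_latent

# The motion UNet's conv_in expects 9 input channels (4 + 1 + 4); the concatenation
# order shown here is an assumption for illustration.
unet_input = torch.cat([latents, mask, masked_image], dim=1)
print(unet_input.shape)  # torch.Size([1, 9, 16, 64, 64])
```

In the pipeline itself this assembly happens in `prepare_masked_condition` and the denoising loop, and under classifier-free guidance the mask and masked-image tensors are duplicated along the batch dimension, as the hunks above show.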