@@ -21,16 +21,33 @@
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import BaseOutput, logging
from ...utils import BaseOutput, logging, replace_example_docstring
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline


logger = logging.get_logger(__name__) # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import StableVideoDiffusionPipeline
>>> from diffusers.utils import load_image, export_to_video

>>> pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
>>> pipe.to("cuda")

>>> image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg")
>>> image = image.resize((1024, 576))

>>> frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
>>> export_to_video(frames, "generated.mp4", fps=7)
```
"""


def _append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
@@ -41,7 +58,7 @@ def _append_dims(x, target_dims):


# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
@@ -65,15 +82,15 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
@dataclass
class StableVideoDiffusionPipelineOutput(BaseOutput):
r"""
Output class for zero-shot text-to-video pipeline.
Output class for Stable Video Diffusion pipeline.

Args:
frames (`[List[PIL.Image.Image]`, `np.ndarray`]):
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
num_channels)`.
frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.FloatTensor`]):
List of denoised PIL images of length `batch_size` or numpy array or torch tensor
of shape `(batch_size, num_frames, height, width, num_channels)`.
"""

frames: Union[List[PIL.Image.Image], np.ndarray]
frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.FloatTensor]


class StableVideoDiffusionPipeline(DiffusionPipeline):
@@ -119,7 +136,13 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance):
def _encode_image(
self,
image: PipelineImageInput,
device: Union[str, torch.device],
num_videos_per_prompt: int,
do_classifier_free_guidance: bool,
) -> torch.FloatTensor:
dtype = next(self.image_encoder.parameters()).dtype

if not isinstance(image, torch.Tensor):
@@ -164,9 +187,9 @@ def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free
def _encode_vae_image(
self,
image: torch.Tensor,
device,
num_videos_per_prompt,
do_classifier_free_guidance,
device: Union[str, torch.device],
num_videos_per_prompt: int,
do_classifier_free_guidance: bool,
):
image = image.to(device=device)
image_latents = self.vae.encode(image).latent_dist.mode()
@@ -186,13 +209,13 @@ def _encode_vae_image(

def _get_add_time_ids(
self,
fps,
motion_bucket_id,
noise_aug_strength,
dtype,
batch_size,
num_videos_per_prompt,
do_classifier_free_guidance,
fps: int,
motion_bucket_id: int,
noise_aug_strength: float,
dtype: torch.dtype,
batch_size: int,
num_videos_per_prompt: int,
do_classifier_free_guidance: bool,
):
add_time_ids = [fps, motion_bucket_id, noise_aug_strength]

@@ -212,7 +235,7 @@ def _get_add_time_ids(

return add_time_ids

def decode_latents(self, latents, num_frames, decode_chunk_size=14):
def decode_latents(self, latents: torch.FloatTensor, num_frames: int, decode_chunk_size: int = 14):
# [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
latents = latents.flatten(0, 1)

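
The `decode_chunk_size` parameter typed above controls how many of the flattened `batch * frames` latents are pushed through the VAE decoder per call. A hedged sketch of that chunking pattern (helper names such as `vae_decode` are placeholders, not the pipeline's actual attributes):

```py
import torch

def decode_in_chunks(latents, num_frames, decode_chunk_size, vae_decode):
    # [batch, frames, channels, h, w] -> [batch * frames, channels, h, w]
    latents = latents.flatten(0, 1)
    frames = []
    for i in range(0, latents.shape[0], decode_chunk_size):
        # decode only a slice of frames at a time to bound peak memory
        frames.append(vae_decode(latents[i : i + decode_chunk_size]))
    frames = torch.cat(frames, dim=0)
    # [batch * frames, channels, h, w] -> [batch, channels, frames, h, w]
    return frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
```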
@@ -257,15 +280,15 @@ def check_inputs(self, image, height, width):

def prepare_latents(
self,
batch_size,
num_frames,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
batch_size: int,
num_frames: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: Union[str, torch.device],
generator: torch.Generator,
latents: Optional[torch.FloatTensor] = None,
):
shape = (
batch_size,
@@ -307,6 +330,7 @@ def num_timesteps(self):
return self._num_timesteps

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
@@ -333,15 +357,16 @@ def __call__(

Args:
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
Image or images to guide image generation. If you provide a tensor, the expected value range is between `[0,1]`.
Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0, 1]`.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
num_frames (`int`, *optional*):
The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`
The number of video frames to generate. Defaults to `self.unet.config.num_frames`
(14 for `stable-video-diffusion-img2vid` and 25 for `stable-video-diffusion-img2vid-xt`).
num_inference_steps (`int`, *optional*, defaults to 25):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
expense of slower inference. This parameter is modulated by `strength`.
min_guidance_scale (`float`, *optional*, defaults to 1.0):
The minimum guidance scale. Used for the classifier free guidance with first frame.
@@ -351,29 +376,29 @@ def __call__(
Frames per second. The rate at which the generated images shall be exported to a video after generation.
Note that Stable Video Diffusion's UNet was micro-conditioned on fps-1 during training.
motion_bucket_id (`int`, *optional*, defaults to 127):
The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
Used for conditioning the amount of motion for the generation. The higher the number the more motion
will be in the video.
noise_aug_strength (`float`, *optional*, defaults to 0.02):
The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
decode_chunk_size (`int`, *optional*):
The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the expense of more memory usage. By default, the decoder decodes all frames at once for maximal
quality. For lower memory usage, reduce `decode_chunk_size`.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
The output format of the generated image. Choose between `pil`, `np` or `pt`.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments:
`callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
`callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
@@ -382,26 +407,12 @@ def __call__(
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.

Examples:

Returns:
[`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list of list with the generated frames.

Examples:

```py
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
pipe.to("cuda")

image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
image = image.resize((1024, 576))

frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```
otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.FloatTensor`) is returned.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
@@ -429,8 +440,7 @@ def __call__(
# 3. Encode input image
image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)

# NOTE: Stable Diffusion Video was conditioned on fps - 1, which
# is why it is reduced here.
# NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
# See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
fps = fps - 1

@@ -471,11 +481,11 @@ def __call__(
)
added_time_ids = added_time_ids.to(device)

# 4. Prepare timesteps
# 6. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps

# 5. Prepare latent variables
# 7. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
@@ -489,15 +499,15 @@ def __call__(
latents,
)

# 7. Prepare guidance scale
# 8. Prepare guidance scale
guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
guidance_scale = guidance_scale.to(device, latents.dtype)
guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
guidance_scale = _append_dims(guidance_scale, latents.ndim)

self._guidance_scale = guidance_scale

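To make the per-frame guidance schedule above concrete, a small worked example (the values are illustrative): with `min_guidance_scale=1.0`, `max_guidance_scale=3.0`, and `num_frames=25`, the first frame is guided with scale 1.0, the last with 3.0, and trailing singleton dimensions are appended so the schedule broadcasts against the 5-D latents:

```py
import torch

min_guidance_scale, max_guidance_scale, num_frames = 1.0, 3.0, 25
guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
# pad with trailing singleton dims so it broadcasts over (batch, frames, c, h, w) latents
guidance_scale = guidance_scale[(...,) + (None,) * 3]
print(guidance_scale.shape)  # torch.Size([1, 25, 1, 1, 1])
```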
# 8. Denoising loop
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -506,7 +516,7 @@ def __call__(
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# Concatenate image_latents over channels dimention
# Concatenate image_latents over channels dimension
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)

# predict the noise residual