
Commit 63f767e

patil-suraj, DN6, patrickvonplaten, and apolinario authored
Add SVD (#5895)
* begin model
* finish blocks
* add_embedding
* addition_time_embed_dim
* use TimestepEmbedding
* fix temporal res block
* fix time_pos_embed
* fix add_embedding
* add conversion script
* fix model
* up
* add new resnet blocks
* make forward work
* return sample in original shape
* fix temb shape in TemporalResnetBlock
* add spatio temporal transformers
* add vae blocks
* fix blocks
* update
* update
* fix shapes in Alphablender and add time activation in res blcok
* use new blocks
* style
* fix temb shape
* fix SpatioTemporalResBlock
* reuse TemporalBasicTransformerBlock
* fix TemporalBasicTransformerBlock
* use TransformerSpatioTemporalModel
* fix TransformerSpatioTemporalModel
* fix time_context dim
* clean up
* make temb optional
* add blocks
* rename model
* update conversion script
* remove UNetMidBlockSpatioTemporal
* add in init
* remove unused arg
* remove unused arg
* remove more unsed args
* up
* up
* check for None
* update vae
* update up/mid blocks for decoder
* begin pipeline
* adapt scheduler
* add guidance scalings
* fix norm eps in temporal transformers
* add temporal autoencoder
* make pipeline run
* fix frame decodig
* decode in float32
* decode n frames at a time
* pass decoding_t to decode_latents
* fix decode_latents
* vae encode/decode in fp32
* fix dtype in TransformerSpatioTemporalModel
* type image_latents same as image_embeddings
* allow using differnt eps in temporal block for video decoder
* fix default values in vae
* pass num frames in decode
* switch spatial to temporal for mixing in VAE
* fix num frames during split decoding
* cast alpha to sample dtype
* fix attention in MidBlockTemporalDecoder
* fix typo
* fix guidance_scales dtype
* fix missing activation in TemporalDecoder
* skip_post_quant_conv
* add vae conversion
* style
* take guidance scale as input
* up
* allow passing PIL to export_video
* accept fps as arg
* add pipeline and vae in init
* remove hack
* use AutoencoderKLTemporalDecoder
* don't scale image latents
* add unet tests
* clean up unet
* clean TransformerSpatioTemporalModel
* add slow svd test
* clean up
* make temb optional in Decoder mid block
* fix norm eps in TransformerSpatioTemporalModel
* clean up temp decoder
* clean up
* clean up
* use c_noise values for timesteps
* use math for log
* update
* fix copies
* doc
* upcast vae
* update forward pass for gradient checkpointing
* make added_time_ids is tensor
* up
* fix upcasting
* remove post quant conv
* add _resize_with_antialiasing
* fix _compute_padding
* cleanup model
* more cleanup
* more cleanup
* more cleanup
* remove freeu
* remove attn slice
* small clean
* up
* up
* remove extra step kwargs
* remove eta
* remove dropout
* remove callback
* remove merge factor args
* clean
* clean up
* move to dedicated folder
* remove attention_head_dim
* docstr and small fix
* update unet doc strings
* rename decoding_t
* correct linting
* store c_skip and c_out
* cleanup
* clean TemporalResnetBlock
* more cleanup
* clean up vae
* clean up
* begin doc
* more cleanup
* up
* up
* doc
* Improve
* better naming
* better naming
* better naming
* better naming
* better naming
* better naming
* better naming
* better naming
* Apply suggestions from code review
* Default chunk size to None
* add example
* Better
* Apply suggestions from code review
* update doc
* Update src/diffusers/pipelines/stable_diffusion_video/pipeline_stable_diffusion_video.py
  Co-authored-by: Patrick von Platen <[email protected]>
* style
* Get torch compile working
* up
* rename
* fix doc
* add chunking
* torch compile
* torch compile
* add modelling outputs
* torch compile
* Improve chunking
* Apply suggestions from code review
* Update docs/source/en/using-diffusers/svd.md
* Close diff tag
* remove slicing
* resnet docstr
* add docstr in resnet
* rename
* Apply suggestions from code review
* update tests
* Fix output type latents
* fix more
* fix more
* Update docs/source/en/using-diffusers/svd.md
* fix more
* add pipeline tests
* remove unused arg
* clean up
* make sure get_scaling receives tensors
* fix euler scheduler
* fix get_scalings
* simply euler for now
* remove old test file
* use randn_tensor to create noise
* fix device for rand tensor
* increase expected_max_difference
* fix test_inference_batch_single_identical
* actually fix test_inference_batch_single_identical
* disable test_save_load_float16
* skip test_float16_inference
* skip test_inference_batch_single_identical
* fix test_xformers_attention_forwardGenerator_pass
* Apply suggestions from code review
* update StableVideoDiffusionPipelineSlowTests
* update image
* add diffusers example
* fix more

---------

Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: apolinário <[email protected]>
1 parent d1b2a1a commit 63f767e

38 files changed: +5287 -149 lines changed

PHILOSOPHY.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -82,7 +82,7 @@ Models are designed as configurable toolboxes that are natural extensions of [Py
 The following design principles are followed:
 - Models correspond to **a type of model architecture**. *E.g.* the [`UNet2DConditionModel`] class is used for all UNet variations that expect 2D image inputs and are conditioned on some context.
 - All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py), [`transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py), etc...
-- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modelling files and shows that models do not really follow the single-file policy.
+- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modeling files and shows that models do not really follow the single-file policy.
 - Models intend to expose complexity, just like PyTorch's `Module` class, and give clear error messages.
 - Models all inherit from `ModelMixin` and `ConfigMixin`.
 - Models can be optimized for performance when it doesn’t demand major code changes, keep backward compatibility, and give significant memory or compute gain.
```

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -94,6 +94,8 @@
     title: Latent Consistency Model-LoRA
   - local: using-diffusers/inference_with_lcm
     title: Latent Consistency Model
+  - local: using-diffusers/svd
+    title: Stable Video Diffusion
   title: Specific pipeline examples
 - sections:
   - local: training/overview
```
docs/source/en/using-diffusers/svd.md

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@

<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Stable Video Diffusion

[[open-in-colab]]

[Stable Video Diffusion](https://static1.squarespace.com/static/6213c340453c3f502425776e/t/655ce779b9d47d342a93c890/1700587395994/stable_video_diffusion.pdf) is a powerful image-to-video generation model that can generate high-resolution (576x1024) 2-4 second videos conditioned on an input image.

This guide will show you how to use SVD to generate short videos from images.

Before you begin, make sure you have the following libraries installed:

```py
!pip install -q -U diffusers transformers accelerate
```

## Image to Video Generation

There are two variants of SVD: [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid) and [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt). The `svd` checkpoint is trained to generate 14 frames and the `svd-xt` checkpoint is further finetuned to generate 25 frames.

We will use the `svd-xt` checkpoint for this guide.

```python
import torch

from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()

# Load the conditioning image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = image.resize((1024, 576))

generator = torch.manual_seed(42)
frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]

export_to_video(frames, "generated.mp4", fps=7)
```

<video width="1024" height="576" controls>
  <source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.mp4?download=true" type="video/mp4">
</video>
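
If you want to try the 14-frame base `svd` checkpoint mentioned above instead, only the model id changes. Here is a minimal sketch of that variation (it assumes the base checkpoint also ships an `fp16` variant):

```python
import torch

from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

# Base SVD checkpoint; trained to generate 14 frames (see above)
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = image.resize((1024, 576))

generator = torch.manual_seed(42)
frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
export_to_video(frames, "generated_14_frames.mp4", fps=7)
```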

<Tip>
Since generating videos is more memory intensive, we can use the `decode_chunk_size` argument to control how many frames are decoded at once. This will reduce the memory usage. It's recommended to tweak this value based on your GPU memory.
Setting `decode_chunk_size=1` will decode one frame at a time and use the least amount of memory, but the video might have some flickering.

Additionally, we use [model cpu offloading](../../optimization/memory#model-offloading) to reduce the memory usage.
</Tip>

### Torch.compile

You can achieve a 20-25% speed-up at the expense of slightly increased memory by compiling the UNet as follows:

```diff
- pipe.enable_model_cpu_offload()
+ pipe.to("cuda")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
```
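
For reference, here is the diff above folded back into the first example as one complete sketch. The first call pays the compilation cost, so the speed-up shows up on subsequent calls:

```python
import torch

from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
# Keep the whole pipeline on the GPU and compile the UNet instead of offloading it
pipe.to("cuda")
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = image.resize((1024, 576))

generator = torch.manual_seed(42)
frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```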

### Low-memory

Video generation is very memory intensive because we essentially have to generate `num_frames` all at once, which is very comparable to text-to-image generation with a high batch size. To reduce the memory requirement, you have multiple options that trade inference speed for a lower memory footprint:
- enable model offloading: each component of the pipeline is offloaded to the CPU once it's not needed anymore.
- enable feed-forward chunking: the feed-forward layer runs in a loop instead of running with a single, huge feed-forward batch size.
- reduce `decode_chunk_size`: the VAE decodes frames in chunks instead of decoding them all together. **Note**: In addition to a small slowdown, this can also lead to a slight deterioration in video quality.

You can enable them as follows:
```diff
-pipe.enable_model_cpu_offload()
-frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
+pipe.enable_model_cpu_offload()
+pipe.unet.enable_forward_chunking()
+frames = pipe(image, decode_chunk_size=2, generator=generator, num_frames=25).frames[0]
```

Including all these tricks should lower the memory requirement to less than 8GB VRAM.
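
For reference, here is the low-memory diff above folded into a single self-contained sketch (the values match the ones suggested here; tweak `decode_chunk_size` for your GPU):

```python
import torch

from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
# Offload pipeline components to the CPU when they are not in use
pipe.enable_model_cpu_offload()
# Run the UNet's feed-forward layers in a loop instead of one large batch
pipe.unet.enable_forward_chunking()

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = image.resize((1024, 576))

generator = torch.manual_seed(42)
# Decode only 2 frames at a time to keep VAE memory usage low
frames = pipe(image, decode_chunk_size=2, generator=generator, num_frames=25).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```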

### Micro-conditioning

Along with the conditioning image, Stable Video Diffusion also accepts micro-conditioning that allows more control over the generated video.
It accepts the following arguments:

- `fps`: The frames per second of the generated video.
- `motion_bucket_id`: The motion bucket id to use for the generated video. This can be used to control the motion of the generated video. Increasing the motion bucket id increases the motion of the generated video.
- `noise_aug_strength`: The amount of noise added to the conditioning image. The higher the value, the less the video resembles the conditioning image. Increasing this value also increases the motion of the generated video.

Here is an example of using micro-conditioning to generate a video with more motion.

```python
import torch

from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()

# Load the conditioning image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = image.resize((1024, 576))

generator = torch.manual_seed(42)
frames = pipe(image, decode_chunk_size=8, generator=generator, motion_bucket_id=180, noise_aug_strength=0.1).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```

<video width="1024" height="576" controls>
  <source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated_motion.mp4?download=true" type="video/mp4">
</video>
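
The `fps` argument from the list above is passed in exactly the same way. As a small illustrative variation of the previous snippet (reusing `pipe`, `image`, and `generator` from it; the values here are only examples):

```python
# Condition on a different frame rate; export with a matching fps
frames = pipe(image, decode_chunk_size=8, generator=generator, fps=10, motion_bucket_id=127, noise_aug_strength=0.02).frames[0]
export_to_video(frames, "generated_fps10.mp4", fps=10)
```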
