
Commit ba49272

19and99 and sayakpaul authored
[Pipeline] Add TextToVideoZeroPipeline (#2954)
* add TextToVideoZeroPipeline and CrossFrameAttnProcessor
* add docs for text-to-video zero
* add teaser image for text-to-video zero docs
* Fix review changes. Add Documentation. Add test
* clean up the codes in pipeline_text_to_video.py. Add descriptive comments and docstrings
* make style && make quality
* make fix-copies
* make requested changes to docs. use huggingface server links for resources, delete res folder
* make style && make quality && make fix-copies
* make style && make quality
* Apply suggestions from code review

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent 074d281 commit ba49272

9 files changed: +837 −1 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -206,6 +206,8 @@
       title: Stochastic Karras VE
   - local: api/pipelines/text_to_video
     title: Text-to-Video
+  - local: api/pipelines/text_to_video_zero
+    title: Text-to-Video Zero
   - local: api/pipelines/unclip
     title: UnCLIP
   - local: api/pipelines/latent_diffusion_uncond
```

docs/source/en/api/pipelines/overview.mdx

Lines changed: 1 addition & 0 deletions
```diff
@@ -83,6 +83,7 @@ available a colab notebook to directly try them out.
 | [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
 | [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
 | [vq_diffusion](./vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
+| [text_to_video_zero](./text_to_video_zero) | [Text2Video-Zero: Text-to-Image Diffusion Models are Zero-Shot Video Generators](https://arxiv.org/abs/2303.13439) | Text-to-Video Generation |

 **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
```
docs/source/en/api/pipelines/text_to_video_zero.mdx

Lines changed: 235 additions & 0 deletions (new file)
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Zero-Shot Text-to-Video Generation
## Overview

[Text2Video-Zero: Text-to-Image Diffusion Models are Zero-Shot Video Generators](https://arxiv.org/abs/2303.13439) by
Levon Khachatryan, Andranik Movsisyan, Vahram Tadevosyan, Roberto Henschel,
[Zhangyang Wang](https://www.ece.utexas.edu/people/faculty/atlas-wang), Shant Navasardyan, and [Humphrey Shi](https://www.humphreyshi.com).
Text2Video-Zero enables zero-shot video generation using either:
1. a textual prompt,
2. a prompt combined with guidance from poses or edges, or
3. Video Instruct-Pix2Pix, i.e., instruction-guided video editing.

Results are temporally consistent and closely follow the guidance and textual prompts.

![teaser-img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/t2v_zero_teaser.png)

The abstract of the paper is the following:

*Recent text-to-video generation approaches rely on computationally heavy training and require large-scale video datasets. In this paper, we introduce a new task of zero-shot text-to-video generation and propose a low-cost approach (without any training or optimization) by leveraging the power of existing text-to-image synthesis methods (e.g., Stable Diffusion), making them suitable for the video domain.
Our key modifications include (i) enriching the latent codes of the generated frames with motion dynamics to keep the global scene and the background time consistent; and (ii) reprogramming frame-level self-attention using a new cross-frame attention of each frame on the first frame, to preserve the context, appearance, and identity of the foreground object.
Experiments show that this leads to low overhead, yet high-quality and remarkably consistent video generation. Moreover, our approach is not limited to text-to-video synthesis but is also applicable to other tasks such as conditional and content-specialized video generation, and Video Instruct-Pix2Pix, i.e., instruction-guided video editing.
As experiments show, our method performs comparably or sometimes better than recent approaches, despite not being trained on additional video data.*
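The cross-frame attention idea can be pictured with a small, self-contained sketch. This is an illustration of the mechanism only, not the `CrossFrameAttnProcessor` shipped with this pipeline; the function name, tensor shapes, and values below are made up for the example. Every frame's self-attention keys and values are swapped for those of the first frame, so all frames attend to frame 0 and keep its appearance and identity.

```python
# Illustrative sketch of cross-frame attention: keys/values of frame 0 are
# reused for every frame, anchoring appearance and identity to the first frame.
import torch
import torch.nn.functional as F


def cross_frame_attention(q, k, v):
    # q, k, v: (frames, tokens, dim) per-frame self-attention projections
    k0 = k[:1].expand_as(k)  # first frame's keys, broadcast to all frames
    v0 = v[:1].expand_as(v)  # first frame's values, broadcast to all frames
    scores = q @ k0.transpose(-1, -2) / q.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ v0


frames, tokens, dim = 8, 64, 32
q, k, v = (torch.randn(frames, tokens, dim) for _ in range(3))
out = cross_frame_attention(q, k, v)  # (8, 64, 32)
```

In `diffusers`, this logic is packaged as the `CrossFrameAttnProcessor` added in this commit, which is also used in the ControlNet and InstructPix2Pix examples below.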
Resources:

* [Project Page](https://text2video-zero.github.io/)
* [Paper](https://arxiv.org/abs/2303.13439)
* [Original Code](https://github.com/Picsart-AI-Research/Text2Video-Zero)

## Available Pipelines:

| Pipeline | Tasks | Demo |
|---|---|:---:|
| [TextToVideoZeroPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py) | *Zero-shot Text-to-Video Generation* | [🤗 Space](https://huggingface.co/spaces/PAIR/Text2Video-Zero) |
## Usage example

### Text-To-Video

To generate a video from a prompt, run the following Python code:

```python
import torch
import imageio
from diffusers import TextToVideoZeroPipeline

model_id = "runwayml/stable-diffusion-v1-5"
pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

prompt = "A panda is playing guitar on times square"
result = pipe(prompt=prompt).images
imageio.mimsave("video.mp4", result, fps=4)
```
You can change these parameters in the pipeline call, as in the example below:
* Motion field strength (see the [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1):
  * `motion_field_strength_x` and `motion_field_strength_y`. Default: `motion_field_strength_x=12`, `motion_field_strength_y=12`
* `T` and `T'` (see the [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1):
  * `t0` and `t1` in the range `{0, ..., num_inference_steps}`. Default: `t0=45`, `t1=48`
* Video length:
  * `video_length`, the number of frames to be generated. Default: `video_length=8`
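For example, the call above can spell out these parameters explicitly (the values shown are just the defaults listed above):

```python
# Same call as before, with the tunable parameters made explicit.
result = pipe(
    prompt=prompt,
    video_length=8,              # number of generated frames
    motion_field_strength_x=12,  # motion field strength along x
    motion_field_strength_y=12,  # motion field strength along y
    t0=45,                       # timesteps used for the motion dynamics,
    t1=48,                       # see Sect. 3.3.1 of the paper
).images
imageio.mimsave("video.mp4", result, fps=4)
```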
### Text-To-Video with Pose Control

To generate a video from a prompt with additional pose control:

1. Download a demo video

```python
from huggingface_hub import hf_hub_download

filename = "__assets__/poses_skeleton_gifs/dance1_corr.mp4"
repo_id = "PAIR/Text2Video-Zero"
video_path = hf_hub_download(repo_type="space", repo_id=repo_id, filename=filename)
```
2. Read the video containing the extracted pose images

```python
import imageio
from PIL import Image

reader = imageio.get_reader(video_path, "ffmpeg")
frame_count = 8
pose_images = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
```

To extract poses from an actual video, read the [ControlNet documentation](./stable_diffusion/controlnet); see also the sketch at the end of this section.
3. Run `StableDiffusionControlNetPipeline` with our custom attention processor

```python
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor

model_id = "runwayml/stable-diffusion-v1-5"
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    model_id, controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Set the attention processor
pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))

# fix latents for all frames
latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(pose_images), 1, 1, 1)

prompt = "Darth Vader dancing in a desert"
result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images
imageio.mimsave("video.mp4", result, fps=4)
```
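If you start from a regular video instead of the pre-extracted pose frames above, the pose images can be estimated with the `controlnet_aux` package used in the ControlNet documentation. A minimal sketch; the package, the checkpoint name, and the `my_video.mp4` input are assumptions for illustration, not part of this pipeline:

```python
# Sketch: estimate OpenPose skeletons for each frame of an ordinary video.
# Assumes `pip install controlnet_aux` and a local file "my_video.mp4".
import imageio
from PIL import Image
from controlnet_aux import OpenposeDetector

pose_detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")

reader = imageio.get_reader("my_video.mp4", "ffmpeg")
frame_count = 8
frames = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
pose_images = [pose_detector(frame) for frame in frames]
```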
### Text-To-Video with Edge Control

To generate a video from a prompt with additional edge control,
follow the steps described above for pose-guided generation, using the [Canny edge ControlNet model](https://huggingface.co/lllyasviel/sd-controlnet-canny) instead; a sketch of the required changes follows below.
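A minimal sketch of those changes relative to the pose example above. The `to_canny` helper, the `cv2.Canny` thresholds, and the `frames`/`canny_images` names are illustrative assumptions; any edge detector producing a 3-channel edge image works:

```python
# Sketch: swap the OpenPose ControlNet for the Canny one and feed edge maps.
import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import ControlNetModel

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)


def to_canny(frame, low=100, high=200):
    edges = cv2.Canny(np.array(frame), low, high)           # single-channel edge map
    return Image.fromarray(np.stack([edges] * 3, axis=-1))  # 3-channel image for ControlNet


# canny_images = [to_canny(frame) for frame in frames]  # frames read as in the pose example
# ...then build StableDiffusionControlNetPipeline exactly as in step 3 above,
# passing `image=canny_images`.
```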
### Video Instruct-Pix2Pix

To perform text-guided video editing (with [InstructPix2Pix](./stable_diffusion/pix2pix)):

1. Download a demo video

```python
from huggingface_hub import hf_hub_download

filename = "__assets__/pix2pix video/camel.mp4"
repo_id = "PAIR/Text2Video-Zero"
video_path = hf_hub_download(repo_type="space", repo_id=repo_id, filename=filename)
```
2. Read the video from the downloaded path

```python
import imageio
from PIL import Image

reader = imageio.get_reader(video_path, "ffmpeg")
frame_count = 8
video = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
```
3. Run `StableDiffusionInstructPix2PixPipeline` with our custom attention processor

```python
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor

model_id = "timbrooks/instruct-pix2pix"
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=3))

prompt = "make it Van Gogh Starry Night style"
result = pipe(prompt=[prompt] * len(video), image=video).images
imageio.mimsave("edited_video.mp4", result, fps=4)
```
### DreamBooth specialization

The **Text-To-Video**, **Text-To-Video with Pose Control**, and **Text-To-Video with Edge Control** methods
can all run with custom [DreamBooth](../training/dreambooth) models, as shown below for the
[Canny edge ControlNet model](https://huggingface.co/lllyasviel/sd-controlnet-canny) and the
[Avatar style DreamBooth](https://huggingface.co/PAIR/text2video-zero-controlnet-canny-avatar) model.

1. Download a demo video from the Hugging Face Hub

```python
from huggingface_hub import hf_hub_download

filename = "__assets__/canny_videos_mp4/girl_turning.mp4"
repo_id = "PAIR/Text2Video-Zero"
video_path = hf_hub_download(repo_type="space", repo_id=repo_id, filename=filename)
```
2. Read the video from the downloaded path

```python
import imageio
from PIL import Image

reader = imageio.get_reader(video_path, "ffmpeg")
frame_count = 8
video = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
```
3. Run `StableDiffusionControlNetPipeline` with the custom-trained DreamBooth model

```python
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor

# set model id to custom model
model_id = "PAIR/text2video-zero-controlnet-canny-avatar"
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    model_id, controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Set the attention processor
pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))

# fix latents for all frames
latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(video), 1, 1, 1)

prompt = "oil painting of a beautiful girl avatar style"
result = pipe(prompt=[prompt] * len(video), image=video, latents=latents).images
imageio.mimsave("video.mp4", result, fps=4)
```
You can browse available DreamBooth-trained models with [this link](https://huggingface.co/models?search=dreambooth), or search programmatically as sketched below.
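A small sketch of a programmatic search with `huggingface_hub`; the `search` and `limit` arguments and the `modelId` attribute are assumed to be available in your installed version:

```python
# Sketch: list a few DreamBooth models on the Hub programmatically.
from huggingface_hub import HfApi

api = HfApi()
for model in api.list_models(search="dreambooth", limit=10):
    print(model.modelId)
```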
## TextToVideoZeroPipeline
[[autodoc]] TextToVideoZeroPipeline
	- all
	- __call__

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -137,6 +137,7 @@
         StableUnCLIPImg2ImgPipeline,
         StableUnCLIPPipeline,
         TextToVideoSDPipeline,
+        TextToVideoZeroPipeline,
         UnCLIPImageVariationPipeline,
         UnCLIPPipeline,
         VersatileDiffusionDualGuidedPipeline,
```

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -68,7 +68,7 @@
         StableUnCLIPPipeline,
     )
     from .stable_diffusion_safe import StableDiffusionPipelineSafe
-    from .text_to_video_synthesis import TextToVideoSDPipeline
+    from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline
     from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
     from .versatile_diffusion import (
         VersatileDiffusionDualGuidedPipeline,
```

src/diffusers/pipelines/text_to_video_synthesis/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -29,3 +29,4 @@ class TextToVideoSDPipelineOutput(BaseOutput):
     from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
 else:
     from .pipeline_text_to_video_synth import TextToVideoSDPipeline  # noqa: F401
+    from .pipeline_text_to_video_zero import TextToVideoZeroPipeline
```
