Commit 96f6aee

DN6 authored and Jimmy committed

Add PIA Model/Pipeline (huggingface#6698)

* update * update * update * add tests and docs * clean up * add to toctree * fix copies * pr review feedback * fix copies * fix tests * update docs * update * update * update docs * update * update * update * update

1 parent 05859a1 commit 96f6aee

File tree

10 files changed, +1762 −1 lines changed


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -302,6 +302,8 @@
      title: MusicLDM
    - local: api/pipelines/paint_by_example
      title: Paint by Example
+   - local: api/pipelines/pia
+     title: Personalized Image Animator (PIA)
    - local: api/pipelines/pixart
      title: PixArt-α
    - local: api/pipelines/self_attention_guidance
docs/source/en/api/pipelines/pia.md

Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Image-to-Video Generation with PIA (Personalized Image Animator)

## Overview

[PIA: Your Personalized Image Animator via Plug-and-Play Modules in Text-to-Image Models](https://arxiv.org/abs/2312.13964) by Yiming Zhang, Zhening Xing, Yanhong Zeng, Youqing Fang, Kai Chen.

Recent advancements in personalized text-to-image (T2I) models have revolutionized content creation, empowering non-experts to generate stunning images with unique styles. While promising, adding realistic motions into these personalized images by text poses significant challenges in preserving distinct styles, high-fidelity details, and achieving motion controllability by text. In this paper, we present PIA, a Personalized Image Animator that excels in aligning with condition images, achieving motion controllability by text, and the compatibility with various personalized T2I models without specific tuning. To achieve these goals, PIA builds upon a base T2I model with well-trained temporal alignment layers, allowing for the seamless transformation of any personalized T2I model into an image animation model. A key component of PIA is the introduction of the condition module, which utilizes the condition frame and inter-frame affinity as input to transfer appearance information guided by the affinity hint for individual frame synthesis in the latent space. This design mitigates the challenges of appearance-related image alignment within and allows for a stronger focus on aligning with motion-related guidance.

[Project page](https://pi-animator.github.io/)

## Available Pipelines

| Pipeline | Tasks | Demo |
|---|---|:---:|
| [PIAPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pia/pipeline_pia.py) | *Image-to-Video Generation with PIA* | |

## Available checkpoints

Motion Adapter checkpoints for PIA can be found under the [OpenMMLab org](https://huggingface.co/openmmlab/PIA-condition-adapter). These checkpoints are meant to work with any model based on Stable Diffusion 1.5.

## Usage example

PIA works with a MotionAdapter checkpoint and a Stable Diffusion 1.5 model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the ResNet and Attention blocks in the Stable Diffusion UNet. In addition to the motion modules, PIA also replaces the input convolution layer of the SD 1.5 UNet with a 9-channel input convolution layer.

The following example demonstrates how to use PIA to generate a video from a single image.

```python
import torch
from diffusers import (
    EulerDiscreteScheduler,
    MotionAdapter,
    PIAPipeline,
)
from diffusers.utils import export_to_gif, load_image

# Load the PIA motion adapter and a Stable Diffusion 1.5 checkpoint
adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter")
pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16)

pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

# Memory saving options
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()

# Condition image to animate
image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
)
image = image.resize((512, 512))
prompt = "cat in a field"
negative_prompt = "wrong white balance, dark, sketches, worst quality, low quality"

generator = torch.Generator("cpu").manual_seed(0)
output = pipe(image=image, prompt=prompt, negative_prompt=negative_prompt, generator=generator)
frames = output.frames[0]
export_to_gif(frames, "pia-animation.gif")
```
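
As a quick sanity check of the 9-channel input convolution mentioned above, you can inspect the loaded UNet. This is a small illustrative sketch; `conv_in` and its attributes are the standard `torch.nn.Conv2d` ones:

```python
# The PIA UNet's first convolution takes 9 input channels: 4 latent channels
# from the noisy sample plus 5 PIA conditioning channels.
conv_in = pipe.unet.conv_in
print(conv_in.in_channels)   # 9
print(conv_in.out_channels)  # width of the first UNet block (320 for SD 1.5)
```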

Here are some sample outputs:

<table>
    <tr>
        <td><center>
        cat in a field
        <br>
        <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pia-default-output.gif"
            alt="cat in a field"
            style="width: 300px;" />
        </center></td>
    </tr>
</table>

<Tip>

If you plan to use a scheduler that can clip samples, disable sample clipping by setting `clip_sample=False` in the scheduler configuration, as clipping can have an adverse effect on the generated samples. Additionally, the PIA checkpoints can be sensitive to the scheduler's beta schedule; we recommend setting it to `linear`.

</Tip>
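
For example, a scheduler that exposes these options could be configured along the following lines. This is a minimal sketch; `DDIMScheduler` is used here only as an illustration of a scheduler with `clip_sample` and `beta_schedule` options:

```python
from diffusers import DDIMScheduler

# Rebuild the scheduler from the pipeline config while disabling sample
# clipping and forcing a linear beta schedule, as recommended above.
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config,
    clip_sample=False,
    beta_schedule="linear",
)
```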

## Using FreeInit

[FreeInit: Bridging Initialization Gap in Video Diffusion Models](https://arxiv.org/abs/2312.07537) by Tianxing Wu, Chenyang Si, Yuming Jiang, Ziqi Huang, Ziwei Liu.

FreeInit is an effective method that improves the temporal consistency and overall quality of videos generated by video diffusion models without any additional training. It can be applied to PIA, AnimateDiff, ModelScope, VideoCrafter, and various other video generation models seamlessly at inference time, and works by iteratively refining the latent initialization noise. More details can be found in the paper.

The following example demonstrates the usage of FreeInit.

```python
import torch
from diffusers import (
    DDIMScheduler,
    MotionAdapter,
    PIAPipeline,
)
from diffusers.utils import export_to_gif, load_image

adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter")
pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter)

# Enable FreeInit
# Refer to the enable_free_init documentation for a full list of configurable parameters
pipe.enable_free_init(method="butterworth", use_fast_sampling=True)

# Memory saving options
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()

pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
)
image = image.resize((512, 512))
prompt = "cat in a hat"
negative_prompt = "wrong white balance, dark, sketches, worst quality, low quality"

generator = torch.Generator("cpu").manual_seed(0)

output = pipe(image=image, prompt=prompt, negative_prompt=negative_prompt, generator=generator)
frames = output.frames[0]
export_to_gif(frames, "pia-freeinit-animation.gif")
```

<table>
    <tr>
        <td><center>
        cat in a hat
        <br>
        <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pia-freeinit-output-cat.gif"
            alt="cat in a hat"
            style="width: 300px;" />
        </center></td>
    </tr>
</table>

<Tip warning={true}>

FreeInit is not really free: the improved quality comes at the cost of extra computation. It requires sampling a few extra times, depending on the `num_iters` parameter set when FreeInit is enabled. Setting `use_fast_sampling=True` reduces this overhead at the cost of some quality compared to `use_fast_sampling=False`, while still producing better results than vanilla video generation.

</Tip>
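
As a rough sketch of how this trade-off can be tuned (the parameter and method names follow `enable_free_init`/`disable_free_init` from the API section below; exact defaults may differ):

```python
# Each FreeInit iteration adds another sampling pass, so fewer iterations
# (or fast sampling) trades quality for speed.
pipe.enable_free_init(method="butterworth", num_iters=3, use_fast_sampling=True)
output = pipe(image=image, prompt=prompt, generator=generator)

# Disable FreeInit to return to ordinary single-pass sampling.
pipe.disable_free_init()
```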

## PIAPipeline

[[autodoc]] PIAPipeline
    - all
    - __call__
    - enable_freeu
    - disable_freeu
    - enable_free_init
    - disable_free_init
    - enable_vae_slicing
    - disable_vae_slicing
    - enable_vae_tiling
    - disable_vae_tiling

## PIAPipelineOutput

[[autodoc]] pipelines.pia.PIAPipelineOutput

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -248,6 +248,7 @@
        "LDMTextToImagePipeline",
        "MusicLDMPipeline",
        "PaintByExamplePipeline",
+       "PIAPipeline",
        "PixArtAlphaPipeline",
        "SemanticStableDiffusionPipeline",
        "ShapEImg2ImgPipeline",
@@ -608,6 +609,7 @@
        LDMTextToImagePipeline,
        MusicLDMPipeline,
        PaintByExamplePipeline,
+       PIAPipeline,
        PixArtAlphaPipeline,
        SemanticStableDiffusionPipeline,
        ShapEImg2ImgPipeline,

src/diffusers/models/unets/unet_motion_model.py

Lines changed: 22 additions & 1 deletion
@@ -89,6 +89,7 @@ def __init__(
        motion_norm_num_groups: int = 32,
        motion_max_seq_length: int = 32,
        use_motion_mid_block: bool = True,
+       conv_in_channels: Optional[int] = None,
    ):
        """Container to store AnimateDiff Motion Modules
@@ -113,6 +114,12 @@
        down_blocks = []
        up_blocks = []

+       if conv_in_channels:
+           # input
+           self.conv_in = nn.Conv2d(conv_in_channels, block_out_channels[0], kernel_size=3, padding=1)
+       else:
+           self.conv_in = None
+
        for i, channel in enumerate(block_out_channels):
            output_channel = block_out_channels[i]
            down_blocks.append(
@@ -410,6 +417,10 @@ def from_unet2d(
        config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"]
        config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"]

+       # For PIA UNets we need to set the number of input channels to 9
+       if motion_adapter.config["conv_in_channels"]:
+           config["in_channels"] = motion_adapter.config["conv_in_channels"]
+
        # Need this for backwards compatibility with UNet2DConditionModel checkpoints
        if not config.get("num_attention_heads"):
            config["num_attention_heads"] = config["attention_head_dim"]
@@ -419,7 +430,17 @@ def from_unet2d(
        if not load_weights:
            return model

-       model.conv_in.load_state_dict(unet.conv_in.state_dict())
+       # Logic for loading PIA UNets which allow the first 4 channels to be any UNet2DConditionModel conv_in weight
+       # while the last 5 channels must be PIA conv_in weights.
+       if has_motion_adapter and motion_adapter.config["conv_in_channels"]:
+           model.conv_in = motion_adapter.conv_in
+           updated_conv_in_weight = torch.cat(
+               [unet.conv_in.weight, motion_adapter.conv_in.weight[:, 4:, :, :]], dim=1
+           )
+           model.conv_in.load_state_dict({"weight": updated_conv_in_weight, "bias": unet.conv_in.bias})
+       else:
+           model.conv_in.load_state_dict(unet.conv_in.state_dict())
+
        model.time_proj.load_state_dict(unet.time_proj.state_dict())
        model.time_embedding.load_state_dict(unet.time_embedding.state_dict())
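
The conv_in weight-merging logic above can be illustrated in isolation. The following standalone sketch uses dummy `nn.Conv2d` layers (not the actual diffusers classes) with shapes matching the 4-channel SD UNet and 9-channel PIA convention:

```python
import torch
import torch.nn as nn

# Dummy stand-ins: a standard SD 1.5 conv_in (4 input channels) and a PIA
# conv_in (9 input channels), both producing 320 feature maps.
sd_conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)
pia_conv_in = nn.Conv2d(9, 320, kernel_size=3, padding=1)

# The first 4 input channels come from the base UNet, the remaining 5 from the
# PIA adapter, mirroring the `from_unet2d` logic in the diff above.
merged_weight = torch.cat([sd_conv_in.weight, pia_conv_in.weight[:, 4:, :, :]], dim=1)

merged = nn.Conv2d(9, 320, kernel_size=3, padding=1)
merged.load_state_dict({"weight": merged_weight, "bias": sd_conv_in.bias})
print(merged.weight.shape)  # torch.Size([320, 9, 3, 3])
```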

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -171,6 +171,7 @@
    _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
    _import_structure["musicldm"] = ["MusicLDMPipeline"]
    _import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
+   _import_structure["pia"] = ["PIAPipeline"]
    _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline"]
    _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
    _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
@@ -415,6 +416,7 @@
        from .latent_diffusion import LDMTextToImagePipeline
        from .musicldm import MusicLDMPipeline
        from .paint_by_example import PaintByExamplePipeline
+       from .pia import PIAPipeline
        from .pixart_alpha import PixArtAlphaPipeline
        from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
        from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
src/diffusers/pipelines/pia/__init__.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_pia"] = ["PIAPipeline", "PIAPipelineOutput"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *

    else:
        from .pipeline_pia import PIAPipeline, PIAPipelineOutput

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
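
With the lazy-import registrations above in place, `PIAPipeline` can be imported straight from the top-level package. A minimal sanity check (assuming `torch` and `transformers` are installed, since the module otherwise falls back to dummy objects):

```python
from diffusers import PIAPipeline
from diffusers.pipelines.pia import PIAPipelineOutput

# Both names resolve lazily through diffusers' _LazyModule machinery.
print(PIAPipeline.__module__)  # diffusers.pipelines.pia.pipeline_pia
```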
