
Conversation

@reallyigor
Contributor

@reallyigor reallyigor commented Sep 1, 2024

What does this PR do?

This PR adds ControlNet support to the AnimateDiff video-to-video pipeline.

Fixes:

See #9326 ([Pipeline] AnimateDiff + VideoToVideo + ControlNet)

Results:

Default pipeline:

(Result GIFs comparing the Depth and OpenPose ControlNets at strengths 0.1, 0.5, and 0.8; images not reproduced here.)

An example with IPAdapter usage alongside ControlNet

(Result GIFs for an astronaut input with ip-adapter-plus_sd15 and a woman input with ip-adapter-plus-face_sd15, each shown with and without the IP-Adapter; images not reproduced here.)

Prompt travel on a TikTok dance video

strength = 0.8
pipe.set_ip_adapter_scale(0.0)

context_length = 16
context_stride = 4
pipe.enable_free_noise(context_length=context_length, context_stride=context_stride)

# Can be a single prompt, or a dictionary mapping frame indices to prompts
prompt = {
    0: "a girl on a winter day, sparkly leaves in the background, snow flakes, close up",
    10: "a girl on a autumn day, yellow leaves in the background, close up",
    20: "a girl on a rainy day, tropical leaves in the background, close up",
}
negative_prompt = "bad quality, worst quality"

with torch.inference_mode():
    video = pipe(
        video=video,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=10,
        guidance_scale=2.0,
        controlnet_conditioning_scale=0.75,
        conditioning_frames=conditioning_frames,
        strength=strength,
        generator=torch.Generator().manual_seed(42),
        ip_adapter_image=ip_adapter_image,
    ).frames[0]

(Result GIF: prompt travel, strength 0.8.)

How to test:

import torch
from tqdm.auto import tqdm

from controlnet_aux.processor import OpenposeDetector
from diffusers import (
    AnimateDiffVideoToVideoControlNetPipeline,
    AutoencoderKL,
    ControlNetModel,
    LCMScheduler,
    MotionAdapter,
)
from diffusers.utils import export_to_gif, load_video

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)
motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)

pipe = AnimateDiffVideoToVideoControlNetPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
).to(device="cuda", dtype=torch.float16)

pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora")
pipe.set_adapters(["lcm-lora"], [0.8])

video = load_video("dance.gif")
video = [frame.convert("RGB") for frame in video]

prompt = "astronaut in space, dancing"
negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"

open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")

conditioning_frames = []
for frame in tqdm(video):
    conditioning_frames.append(open_pose(frame))

strength = 0.8
with torch.inference_mode():
    video = pipe(
        video=video,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=10,
        guidance_scale=2.0,
        controlnet_conditioning_scale=0.75,
        conditioning_frames=conditioning_frames,
        strength=strength,
        generator=torch.Generator().manual_seed(42),
    ).frames[0]

video = [frame.resize(conditioning_frames[0].size) for frame in video]
export_to_gif(video, "animatediff_vid2vid_controlnet.gif", fps=8)

Before submitting

I created new tests by adapting ones from the regular AnimateDiff Video-to-Video pipeline.

pytest tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py

My code fails only on one test:

test_from_pipe_consistent_forward_pass_cpu_offload
========================================= 1 failed, 37 passed, 2 skipped, 37 warnings in 42.66s ==========================================

The problem is that the original AnimateDiff Video To Video pipeline (already in Diffusers) also fails this test, and I couldn't identify the cause. These are the results for the original AnimateDiff Video To Video:

pytest tests/pipelines/animatediff/test_animatediff_video2video.py 
========================================= 1 failed, 37 passed, 2 skipped, 37 warnings in 40.86s ==========================================

Who can review?

@DN6 @a-r-r-o-w

@reallyigor reallyigor marked this pull request as ready for review September 1, 2024 16:25
@reallyigor reallyigor changed the title Add animatediff + vid2vide + controlnet [Pipeline] animatediff + vid2vide + controlnet Sep 1, 2024
@reallyigor reallyigor changed the title [Pipeline] animatediff + vid2vide + controlnet [Pipeline] animatediff + vid2vid + controlnet Sep 1, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@a-r-r-o-w a-r-r-o-w left a comment

Thank you very much, this is looking really good! I would appreciate seeing a few experimental results though, just to verify functionality:

  • An example making use of enforce_inference_steps=True, possibly latent upscaling from the other reference PRs, or anything unique of your choice
  • An example making use of the FreeNoise prompt travel feature
  • An example with IPAdapter usage alongside ControlNet

By looking at the diffs, all changes look great to me 💯

@reallyigor
Contributor Author

reallyigor commented Sep 2, 2024

An example with IPAdapter usage alongside ControlNet

@a-r-r-o-w

(Result GIFs for an astronaut input with ip-adapter-plus_sd15 and a woman input with ip-adapter-plus-face_sd15, each shown with and without the IP-Adapter; images not reproduced here.)

@reallyigor
Contributor Author

Prompt travel

strength = 0.8
pipe.set_ip_adapter_scale(0.0)

context_length = 16
context_stride = 4
pipe.enable_free_noise(context_length=context_length, context_stride=context_stride)

# Can be a single prompt, or a dictionary mapping frame indices to prompts
prompt = {
    0: "an austonaut on a winter day, sparkly leaves in the background, snow flakes, close up",
    6: "an austonaut on a autumn day, yellow leaves in the background, close up",
    12: "an austonaut on a rainy day, tropical leaves in the background, close up",
}
negative_prompt = "bad quality, worst quality"

with torch.inference_mode():
    video = pipe(
        video=video,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=10,
        guidance_scale=2.0,
        controlnet_conditioning_scale=0.75,
        conditioning_frames=conditioning_frames,
        strength=strength,
        generator=torch.Generator().manual_seed(42),
        ip_adapter_image=ip_adapter_image,
    ).frames[0]

(Result GIF: prompt travel.)

@reallyigor
Contributor Author

Prompt travel on a TikTok dance video

strength = 0.8
pipe.set_ip_adapter_scale(0.0)

context_length = 16
context_stride = 4
pipe.enable_free_noise(context_length=context_length, context_stride=context_stride)

# Can be a single prompt, or a dictionary mapping frame indices to prompts
prompt = {
    0: "a girl on a winter day, sparkly leaves in the background, snow flakes, close up",
    10: "a girl on a autumn day, yellow leaves in the background, close up",
    20: "a girl on a rainy day, tropical leaves in the background, close up",
}
negative_prompt = "bad quality, worst quality"

with torch.inference_mode():
    video = pipe(
        video=video,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=10,
        guidance_scale=2.0,
        controlnet_conditioning_scale=0.75,
        conditioning_frames=conditioning_frames,
        strength=strength,
        generator=torch.Generator().manual_seed(42),
        ip_adapter_image=ip_adapter_image,
    ).frames[0]

(Result GIF: prompt travel on the dance video, strength 0.8.)

@reallyigor
Contributor Author

reallyigor commented Sep 2, 2024

Latent upscaling also works

strength = 0.8
pipe.set_ip_adapter_scale(0.0)

context_length = 16
context_stride = 4
pipe.enable_free_noise(context_length=context_length, context_stride=context_stride)

# Can be a single prompt, or a dictionary mapping frame indices to prompts
prompt = {
    0: "a girl on a winter day, sparkly leaves in the background, snow flakes, close up",
    10: "a girl on a autumn day, yellow leaves in the background, close up",
    20: "a girl on a rainy day, tropical leaves in the background, close up",
}
negative_prompt = "bad quality, worst quality"

with torch.inference_mode():
    latents = pipe(
        video=video,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=10,
        guidance_scale=2.0,
        controlnet_conditioning_scale=0.75,
        conditioning_frames=conditioning_frames,
        strength=strength,
        generator=torch.Generator().manual_seed(42),
        ip_adapter_image=ip_adapter_image,
        output_type="latent",
    ).frames

import torch.nn.functional as F

# Run latent upscaling
# Note that only naive upscaling is done here. Alternatively, a latent upscaler
# model could be used

batch_size, num_channels, num_frames, latent_height, latent_width = latents.shape
# With scale_factor = 1 the latents keep their original size; raise it for an actual upscale
scale_factor = 1
scale_method = "nearest-exact"
upscaled_latent_height = int(latent_height * scale_factor)
upscaled_latent_width = int(latent_width * scale_factor)
strength = 0.5

upscaled_latents = []
for i in range(batch_size):
    latent = F.interpolate(latents[i], size=(upscaled_latent_height, upscaled_latent_width), mode=scale_method)
    upscaled_latents.append(latent.unsqueeze(0))
upscaled_latents = torch.cat(upscaled_latents, dim=0)

# Run pipeline for denoising upscaled latents
with torch.inference_mode():
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=10,
        guidance_scale=2.0,
        controlnet_conditioning_scale=0.75,
        conditioning_frames=conditioning_frames,
        strength=strength,
        generator=torch.Generator().manual_seed(42),
        ip_adapter_image=ip_adapter_image,
        output_type="pil",
        latents=upscaled_latents,
        enforce_inference_steps=True,
    ).frames[0]
   
result = [frame.resize(conditioning_frames[0].size) for frame in result]
export_to_gif(result, "latent_upscaled.gif", fps=8)

(Result GIF: latent upscaled output.)

@reallyigor reallyigor requested a review from a-r-r-o-w September 2, 2024 19:30
Contributor

@a-r-r-o-w a-r-r-o-w left a comment

LGTM. cc @DN6 if you're free to give this a look

@reallyigor
Contributor Author

I've updated the links to the input video and added an example output to the docs 🥳🥳

@reallyigor reallyigor requested a review from a-r-r-o-w September 3, 2024 07:17
@a-r-r-o-w
Contributor

a-r-r-o-w commented Sep 3, 2024

Could you run make style and make fix-copies so that the style tests pass? I would be careful with make fix-copies because if # Copied from has been used incorrectly somewhere, it will overwrite your changes.

@reallyigor
Contributor Author

reallyigor commented Sep 3, 2024

Sure, I used to run it before every commit, but forgot last time. Done ✅

@a-r-r-o-w
Contributor

A couple of failing tests need to be addressed here before merge:

FAILED tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py::AnimateDiffVideoToVideoControlNetPipelineFastTests::test_ip_adapter_single - AttributeError: 'super' object has no attribute 'test_ip_adapter_single'
  • The failing test here is unrelated to the PR. If you update the branch with main, I can restart the test run so it passes

@reallyigor
Contributor Author

It wasn't easy to understand what value to use for the expected_pipe_slice in the test_ip_adapter, but I figured it out :)
@a-r-r-o-w

@a-r-r-o-w
Contributor

Hey, this is looking good. The failing test is due to the tests being run on different machine types. To get the correct numbers, we'd have to get them from the specific CPU runners we use. I can get them to you some time tomorrow or over the weekend, since I'm caught up with other things at the moment.

@reallyigor
Contributor Author

To get the correct numbers, we'd have to get them from the specific CPU runners we use. I can get them to you some time tomorrow or over the weekend, since I'm caught up with other things at the moment.

Thank you <3, I appreciate it!

@reallyigor
Contributor Author

Off topic: isn't it weird that CPU tests depend on the CPU type?
This test passes on my machine :/

@reallyigor
Contributor Author

@a-r-r-o-w
please take a look when you have a moment 🙏

@a-r-r-o-w
Contributor

hey, sorry for the delay! here you go:

[0.5569 0.6250 0.4144 0.5613 0.5563 0.5213
 0.5091 0.4950 0.4950 0.5684 0.3858 0.4863
 0.6457 0.4311 0.5517 0.5608 0.4417 0.5377]
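
For reference, a minimal sketch of how these values might be wired into a test as the expected slice. The exact helper and tolerance used by the diffusers test mixin are assumptions here, and output_slice is stubbed only to keep the snippet runnable:

import numpy as np

# Reference values from the CI CPU runner (see above)
expected_slice = np.array(
    [0.5569, 0.6250, 0.4144, 0.5613, 0.5563, 0.5213,
     0.5091, 0.4950, 0.4950, 0.5684, 0.3858, 0.4863,
     0.6457, 0.4311, 0.5517, 0.5608, 0.4417, 0.5377]
)

# In the real test, `output_slice` is the flattened slice taken from the pipeline
# output on the CI runner; it is stubbed here so the sketch runs standalone.
output_slice = expected_slice.copy()

max_diff = np.abs(output_slice - expected_slice).max()
assert max_diff < 1e-3, f"max difference {max_diff} exceeds tolerance"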

@reallyigor
Contributor Author

Thanks a lot, Aryan!
I updated the test ✅

Contributor

@a-r-r-o-w a-r-r-o-w left a comment

Very cool PR @reallyigor! Everything was correct for the most part from the very beginning, which is quite rare to come across, so really great work :)

@reallyigor
Contributor Author

Very cool PR @reallyigor! Everything was correct for the most part from the very beginning, which is quite rare to come across, so really great work :)

Thank you for your kind words 😍 , I really appreciate it

@reallyigor
Contributor Author

What are the next steps for merging?

@a-r-r-o-w
Contributor

Nothing much, it looks great! I'm just waiting to check if @DN6 would like to give this a review too by Monday, since it was his ask initially. Happy to merge by EOD tomorrow even if he isn't able to take a look, because it seems like the results are as expected :)

@reallyigor
Contributor Author

🥳

@reallyigor
Contributor Author

@DN6
Do you mind taking a look when you have a moment? 🙏

@a-r-r-o-w a-r-r-o-w merged commit a7361dc into huggingface:main Sep 9, 2024
@a-r-r-o-w
Contributor

Glad to have you as a first-time-contributor! 🥳

@HanLiii

HanLiii commented Sep 21, 2024

@a-r-r-o-w @reallyigor Thank you for the great work!
My diffusers version is v0.30.3. I tried to import AnimateDiffVideoToVideoControlNetPipeline, but it seems that AnimateDiffVideoToVideoControlNetPipeline doesn't exist in diffusers v0.30.3:

from diffusers import AnimateDiffVideoToVideoControlNetPipeline
ImportError: cannot import name 'AnimateDiffVideoToVideoControlNetPipeline' from 'diffusers' (/data/hli358/envs/animatediff/lib/python3.10/site-packages/diffusers/__init__.py)

@a-r-r-o-w
Contributor

AnimateDiffVideoToVideoControlNetPipeline doesn't exist in diffusers v0.30.3

This pipeline was not shipped with 0.30.3. The only changes in that release were CogVideoX vid2vid and img2vid.

For now, you will have to install diffusers from the main branch to be able to use this pipeline. So, pip install git+https://github.com/huggingface/diffusers should hopefully fix this error.
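
For example, a quick sanity check after installing from main (the version check is only illustrative):

# After running in a shell:
#   pip install git+https://github.com/huggingface/diffusers
import diffusers
from diffusers import AnimateDiffVideoToVideoControlNetPipeline

# A ".dev" version newer than 0.30.3 indicates the main branch is installed
print(diffusers.__version__)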

@HanLiii

HanLiii commented Sep 25, 2024

@a-r-r-o-w, thank you for the help! I wonder if the AnimateDiffVideoToVideoControlNetPipeline can support the tile ControlNet from https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile, similar to how it supports the OpenPose ControlNet in the example:

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16
)

If it currently doesn't support this, is there a way to modify the pipeline code to add support for it?

@a-r-r-o-w
Contributor

Our ControlNetModel supports all the available controlnets I think. cc @asomoza in case I'm wrong.

I believe it should work smoothly but if you're facing any errors with loading/inference, feel free to open a new issue and we can try and help there (so that others from the diffusers team can take a look too)
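
For reference, a minimal sketch of plugging the tile checkpoint into this pipeline, reusing motion_adapter, vae, and video from the test script earlier in the thread; the parameter values are illustrative, not validated settings:

import torch
from diffusers import AnimateDiffVideoToVideoControlNetPipeline, ControlNetModel

# Load the tile checkpoint in place of the OpenPose one used above
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11f1e_sd15_tile", torch_dtype=torch.float16
)

pipe = AnimateDiffVideoToVideoControlNetPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",
    motion_adapter=motion_adapter,  # loaded as in the test script above
    controlnet=controlnet,
    vae=vae,                        # loaded as in the test script above
).to(device="cuda", dtype=torch.float16)

# The tile ControlNet is conditioned on the frames themselves rather than on pose
# maps, so the input frames double as conditioning_frames here.
with torch.inference_mode():
    result = pipe(
        video=video,
        prompt="best quality, astronaut in space, dancing",
        negative_prompt="bad quality, worst quality",
        num_inference_steps=25,
        guidance_scale=7.5,
        controlnet_conditioning_scale=0.75,
        conditioning_frames=video,
        strength=0.8,
        generator=torch.Generator().manual_seed(42),
    ).frames[0]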

@HanLiii

HanLiii commented Sep 25, 2024

Thank you for the clarification! The tile ControlNet loads smoothly into the pipeline, but it doesn't seem to have the expected effect on the output. Below are the results for comparison:

•	Tile ControlNet output (using script from lllyasviel/control_v11f1e_sd15_tile):

(Image: Tile ControlNet output.)

•	AnimateDiffVideoToVideoControlNetPipeline with Tile ControlNet output, guidance_scale=7.5, LCMScheduler:

(Image: AnimateDiffVideoToVideoControlNetPipeline with Tile ControlNet.)

•	AnimateDiffVideoToVideoControlNetPipeline with OpenPose ControlNet output:

(Image: AnimateDiffVideoToVideoControlNetPipeline with OpenPose ControlNet.)

For context, the second and third tests use this input video:
https://github.com/user-attachments/assets/34f0aa12-c0b6-4b62-a43d-d8aaa7676c70

The first test uses only the first frame of the input video, and all tests use the same prompt: "best quality, astronaut in space, dancing." The output from AnimateDiffVideoToVideoControlNetPipeline with the Tile ControlNet appears noticeably more blurry than both the standalone Tile ControlNet output and the pipeline's OpenPose ControlNet output. Could you provide any guidance on how to resolve this issue?

@asomoza
Member

asomoza commented Sep 25, 2024

Yes, we support all available ControlNets for SD 1.5. Personally I haven't tested the SD 1.5 Tile one, not even with a single image.

Just to understand better: the comparison you're doing is between a single image and AnimateDiff, right? I've never seen anyone use Tile for animations; maybe you can test it with something like ComfyUI to see whether we have something wrong or the Tile + AnimateDiff combination simply doesn't work.

@a-r-r-o-w
Contributor

The AnimateDiff motion adapters are known to have a bit of a blurring effect and poorer quality when it comes to following image/video conditions. I can do some tests soon when I find time to help you more on this.

Just curious, do you notice this behaviour when using AnimateLCM with tile controlnet? If not, it might be because the original motion adapters are not the best when it comes to high resolution animation quality (even with a controlnet). Typically, you have to involve more tricks like latent upscaling, unsampling, adetailer, etc. for good results.

@reallyigor
Contributor Author

reallyigor commented Sep 25, 2024

Hey @HanLiii,
Based on the results you attached, the CFG scale is probably set too high, which is why the image looks oversaturated.
I would also definitely try using the original noise scheduler rather than LCM (LCM tends to blur because it uses far fewer denoising steps).
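
For example, a hedged sketch of switching back to a standard scheduler and zeroing out the LCM LoRA, reusing the variables from the test script above (the scheduler choice, step count, and guidance value are illustrative, not recommended settings):

import torch
from diffusers import DDIMScheduler

# Swap the LCM scheduler for a standard one and disable the AnimateLCM LoRA
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.set_adapters(["lcm-lora"], [0.0])

with torch.inference_mode():
    video_out = pipe(
        video=video,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=25,      # more steps than the LCM setup
        guidance_scale=5.0,          # moderate CFG to avoid oversaturation
        controlnet_conditioning_scale=0.75,
        conditioning_frames=conditioning_frames,
        strength=0.8,
        generator=torch.Generator().manual_seed(42),
    ).frames[0]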

@HanLiii

HanLiii commented Sep 25, 2024

@reallyigor @a-r-r-o-w @asomoza, thanks for the reply!

In my earlier comment, I included outputs from the AnimateDiffVideoToVideoControlNetPipeline using the OpenPose ControlNet. The results were vivid and detailed, which was encouraging.

However, when I use the same pipeline with the Tile ControlNet, the outputs appear blurry and lack the vivid details present in the OpenPose results.

To troubleshoot, I tried:

Adjusting guidance_scale from 7.5 down to 3.0.
Replacing LCMScheduler with DDIMScheduler.
Removing the motion module to see if it affected the output quality.

Despite these changes, the outputs remain blurry. Here’s an example using Tile ControlNet with guidance_scale=3.0 and DDIMScheduler:

•	AnimateDiffVideoToVideoControlNetPipeline with Tile ControlNet output, guidance_scale=3.0, DDIMScheduler:

(Image: Tile ControlNet output with DDIMScheduler, guidance_scale=3.0.)

sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* add animatediff + vid2vide + controlnet

* post tests fixes

* PR discussion fixes

* update docs

* change input video to links on HF + update an example

* make quality fix

* fix ip adapter test

* fix ip adapter test input

* update ip adapter test