
[MS Text To Video] Add first text to video #2738


Merged: 48 commits into main on Mar 22, 2023

Conversation

patrickvonplaten
Contributor

@patrickvonplaten patrickvonplaten commented Mar 19, 2023

This PR adds the text-to-video model from ModelScope: https://modelscope.cn/models/damo/text-to-video-synthesis/summary

Also see: https://www.reddit.com/r/StableDiffusion/comments/11vbyei/first_open_source_text_to_video_17_billion/

The model consists of three components:

  • Text encoder: the same as Stable Diffusion 2.1
  • UNet3D: structure quite similar to SD's UNet
  • Latent upscaler: the same as Stable Diffusion 2.1

A simple command to run the model:

import torch
from diffusers import TextToVideoMSPipeline, DPMSolverMultistepScheduler
from diffusers.utils import export_to_video

pipe = TextToVideoMSPipeline.from_pretrained("diffusers/ms-text-to-video-sd", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

prompt = "Spiderman is surfing"
video_frames = pipe(prompt).frames
video_path = export_to_video(video_frames)
print(video_path)

To compare results against the original model:

import cv2
import tempfile
from huggingface_hub import HfApi
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline

seed = 0
api = HfApi()

prompt = "An astronaut riding a horse"


def write_video(video):
    output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    h, w, c = video[0].shape
    video_writer = cv2.VideoWriter(
        output_video_path, fourcc, fps=8, frameSize=(w, h))
    for i in range(len(video)):
        img = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
        video_writer.write(img)
    # release the writer so the mp4 is fully flushed to disk
    video_writer.release()
    return output_video_path


pipe = pipeline('text-to-video-synthesis', "/home/patrick_huggingface_co/text_to_video_model_scope/weights")
torch.manual_seed(seed)
video_path = pipe({'text': prompt})[OutputKeys.OUTPUT_VIDEO]

api.upload_file(
    path_or_fileobj=video_path,
    path_in_repo="video_orig.mp4",
    repo_id="patrickvonplaten/videos",
    repo_type="dataset",
)
del pipe
print("https://huggingface.co/datasets/patrickvonplaten/videos/blob/main/video_orig.mp4")

torch.cuda.empty_cache()

pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", variant="fp16", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

generator = torch.manual_seed(seed)
latents = torch.randn((1, 4, 16, 32, 32)).half()
video = pipe(prompt, latents=latents, num_inference_steps=25)[0]

video_path = write_video(video)
api.upload_file(
    path_or_fileobj=video_path,
    path_in_repo="video_diff.mp4",
    repo_id="patrickvonplaten/videos",
    repo_type="dataset",
)
print("https://huggingface.co/datasets/patrickvonplaten/videos/blob/main/video_diff.mp4")

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Mar 19, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten patrickvonplaten changed the title [MS Text To Video} Add first text to video [MS Text To Video] Add first text to video Mar 19, 2023
@sayakpaul
Member

@patrickvonplaten I checked the implementation of AutoencoderKL in ModelScope and verified it against what we have in diffusers.

The implementations are functionally the same (see the Colab).

We (99% likely) just have to convert the parameters.

@sayakpaul
Member

@patrickvonplaten here is a Colab Notebook that shows the minor changes needed to make it work with create_vae_diffusers_config() and convert_ldm_vae_checkpoint(). Might be useful.

https://colab.research.google.com/gist/sayakpaul/f1b55ebcd2c850fcdeda351f3a4599e8/scratchpad.ipynb
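
For reference, a rough sketch of what that conversion could look like, assuming the create_vae_diffusers_config() and convert_ldm_vae_checkpoint() helpers from the Stable Diffusion conversion script plus hypothetical local paths for the original config and the VQGAN_autoencoder.pth weights; the exact tweaks needed are in the Colab above:

import torch
from omegaconf import OmegaConf
from diffusers import AutoencoderKL

# hypothetical local paths for the original ModelScope VAE config and weights
original_config = OmegaConf.load("vae_config.yaml")
state_dict = torch.load("VQGAN_autoencoder.pth", map_location="cpu")
state_dict = state_dict.get("state_dict", state_dict)  # unwrap if nested

# build a diffusers AutoencoderKL config from the original config and remap the weights
vae_config = create_vae_diffusers_config(original_config, image_size=256)
converted_state_dict = convert_ldm_vae_checkpoint(state_dict, vae_config)

vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_state_dict)
vae.save_pretrained("ms-text-to-video-1.7b/vae")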

@sayakpaul
Member

@patrickvonplaten I uploaded it here: https://huggingface.co/diffusers/ms-text-to-video-1.7b/tree/main/vae

Here's the Colab Notebook I used: https://colab.research.google.com/gist/sayakpaul/930d6f582e4c5e381db1b392b479141b/scratchpad.ipynb

One weird thing: after converting the checkpoints from ModelScope to diffusers, I am seeing a big reduction in checkpoint size.

https://huggingface.co/damo-vilab/modelscope-damo-text-to-video-synthesis/blob/main/VQGAN_autoencoder.pth (original VAE checkpoints) is ~5 GB.

Ours (https://huggingface.co/diffusers/ms-text-to-video-1.7b/blob/main/vae/diffusion_pytorch_model.bin) is 335 MB.
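
If it helps narrow this down, a quick sketch for inspecting both files locally (file names are just the two checkpoints linked above, downloaded next to the script):

import torch

orig = torch.load("VQGAN_autoencoder.pth", map_location="cpu")
orig_sd = orig.get("state_dict", orig) if isinstance(orig, dict) else orig
ours = torch.load("diffusion_pytorch_model.bin", map_location="cpu")

def num_params(sd):
    return sum(t.numel() for t in sd.values() if torch.is_tensor(t))

print("original params:", num_params(orig_sd))
print("converted params:", num_params(ours))

# check whether the original .pth bundles extra state (EMA weights, the
# discriminator / loss modules, optimizer state, ...) that would not end up
# in the converted AutoencoderKL
print(list(orig_sd.keys())[:20])

If the original checkpoint does carry loss or EMA weights, that alone could explain a good chunk of the gap.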

@patrickvonplaten
Contributor Author

patrickvonplaten commented Mar 21, 2023

The pipeline works. It also seems to work with other schedulers and with fp16, and it can run with just 7 GB of memory using Torch 2.0:

#!/usr/bin/env python3
import cv2
import tempfile
from huggingface_hub import HfApi
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch


api = HfApi()


def write_video(video):
    output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    h, w, c = video[0].shape
    video_writer = cv2.VideoWriter(
        output_video_path, fourcc, fps=8, frameSize=(w, h))
    for i in range(len(video)):
        img = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
        video_writer.write(img)
    # release the writer so the mp4 is fully flushed to disk
    video_writer.release()
    return output_video_path


pipe = DiffusionPipeline.from_pretrained("diffusers/ms-text-to-video-sd", variant="fp16, torch_device=torch.float16)                                                                                                                   
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
video = pipe("Spiderman is surfing", num_inference_steps=25).image

video_path = write_video(video)

api.upload_file(
    path_or_fileobj=video_path,
    path_in_repo="video.mp4",
    repo_id="patrickvonplaten/videos",
    repo_type="dataset",
)
print("https://huggingface.co/datasets/patrickvonplaten/videos/blob/main/video.mp4")

@patrickvonplaten
Contributor Author

Can generate up to 8 seconds of video (64 frames at 8 fps) on a V100 thanks to VAE slicing:

pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()
video = pipe("Darth Vader surfing a wave", num_frames=64, num_inference_steps=25).image

@@ -0,0 +1,667 @@
# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
Contributor Author

Add Alibaba citation.

Member

Probably not a big deal, but the pipeline code is arguably mostly HF :)

Contributor Author

Yeah, true. I removed it there, since there is really no code from Alibaba except tensor2vid, which is tiny and where we left a link and a comment.

@@ -0,0 +1,492 @@
# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
Contributor Author

Add Alibaba copyright.

@@ -1,3 +1,18 @@
# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
Contributor Author

Add Alibaba copyright.

class TemporalConvLayer(nn.Module):
"""
Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
Contributor Author

Add comment that it's copied code
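
For context, the gist of that copied layer (a simplified sketch, not the exact ModelScope code): reshape the flattened (batch * frames) activations into a 5D tensor, convolve across the frame axis only, and add the result back as a residual.

import torch
import torch.nn as nn

class TemporalConvSketch(nn.Module):
    """Illustrative only; the real layer stacks several GroupNorm/SiLU/Conv3d blocks."""

    def __init__(self, dim: int):
        super().__init__()
        # kernel (3, 1, 1): mix information across neighboring frames, not across space
        self.conv = nn.Sequential(
            nn.GroupNorm(32, dim),
            nn.SiLU(),
            nn.Conv3d(dim, dim, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
        )

    def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
        # (batch * frames, channels, h, w) -> (batch, channels, frames, h, w)
        bf, c, h, w = hidden_states.shape
        x = hidden_states.reshape(bf // num_frames, num_frames, c, h, w).permute(0, 2, 1, 3, 4)
        x = x + self.conv(x)  # temporal mixing as a residual update
        # back to the flattened layout the surrounding 2D blocks expect
        return x.permute(0, 2, 1, 3, 4).reshape(bf, c, h, w)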

Member

@pcuenca pcuenca left a comment

Amazing work! Just pointed out a few minor questions, but this is totally good to go imo.

Comment on lines +108 to +112
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
Member

This is always going to be skipped in our current CI, I think (it uses PyTorch 2, unless I'm mistaken).

Contributor Author

Yeah we should maybe clean this up soon


assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_pix2pix_negative_prompt(self):
Member

Does pix2pix work? 😮

Contributor Author

Yeah not sure - removed this one 😅

Member

Should actually work. Just needed a renaming.

Member

Can we pop it back in?

Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
return_dict (`bool`, *optional*, defaults to `True`):
Member

output_type is missing from the docstring.
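
A possible entry, modeled on the other diffusers pipelines (the default and accepted values here are assumptions, not necessarily what this PR ships):

    output_type (`str`, *optional*, defaults to `"np"`):
        The output format of the generated video. Choose between `"np"` (a list of numpy arrays) and
        `"pt"` (torch tensors).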

Comment on lines +44 to +47
>>> pipe = TextToVideoSDPipeline.from_pretrained(
... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
... )
>>> pipe.enable_model_cpu_offload()
Contributor

We had previously talked about example snippets standardizing on enabling all optimizations, including UniPCMultistepScheduler, xformers, and 20 steps. Is that worth adding here?

Contributor Author

Yeah, good point. I'm somewhat assuming people use Torch 2.0 now, so there's no need for xformers anymore. UniPC doesn't work well with this model, but DPM does. We should maybe add that here, something like the sketch below.
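
E.g., the docstring example could be extended along these lines (a sketch combining the snippets above; the exact wording for the docs is open):

import torch
from diffusers import TextToVideoSDPipeline, DPMSolverMultistepScheduler
from diffusers.utils import export_to_video

pipe = TextToVideoSDPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
)
# DPM-Solver++ works well with this model (UniPC less so)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

video_frames = pipe("Spiderman is surfing", num_inference_steps=25).frames
video_path = export_to_video(video_frames)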

Comment on lines +23 to +27
def to_np(tensor):
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()

return tensor
Contributor

What's the reason we need to convert to numpy arrays here?

Contributor Author

Text to video returns PyTorch tensors

Contributor

@williamberman williamberman left a comment

lgtm!

@patrickvonplaten
Contributor Author

Tests passing locally, merging now

@patrickvonplaten patrickvonplaten merged commit ca1a222 into main Mar 22, 2023
@patrickvonplaten patrickvonplaten deleted the text_to_video branch March 22, 2023 17:39
@kabachuha
Contributor

Awesome! Thank you for all the great work! Dreambooth is the next step, I guess

w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
* [MS Text To Video} Add first text to video

* upload

* make first model example

* match unet3d params

* make sure weights are correcctly converted

* improve

* forward pass works, but diff result

* make forward work

* fix more

* finish

* refactor video output class.

* feat: add support for a video export utility.

* fix: opencv availability check.

* run make fix-copies.

* add: docs for the model components.

* add: standalone pipeline doc.

* edit docstring of the pipeline.

* add: right path to TransformerTempModel

* add: first set of tests.

* complete fast tests for text to video.

* fix bug

* up

* three fast tests failing.

* add: note on slow tests

* make work with all schedulers

* apply styling.

* add slow tests

* change file name

* update

* more correction

* more fixes

* finish

* up

* Apply suggestions from code review

* up

* finish

* make copies

* fix pipeline tests

* fix more tests

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>

* apply suggestions

* up

* revert

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024