[MS Text To Video] Add first text to video #2738
Conversation
The documentation is not available anymore as the PR was closed or merged.
@patrickvonplaten I checked the implementation of AutoencoderKL in ModelScope and verified it against what we have in diffusers. The implementations are functionally the same (check the Colab). We (99% likely) just have to convert the parameters.
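For illustration, a minimal sketch of the kind of equivalence check the Colab performs. It assumes you pass in two decode callables (e.g. the ModelScope decoder and the diffusers AutoencoderKL decoder, both already loaded with the converted weights; the names `ms_vae` and `hf_vae` below are hypothetical handles):
import torch

def decoders_match(decode_a, decode_b, latent_shape=(1, 4, 32, 32), atol=1e-4):
    """Run the same random latent through both decoders and compare the outputs."""
    latents = torch.randn(latent_shape)
    with torch.no_grad():
        return torch.allclose(decode_a(latents), decode_b(latents), atol=atol)

# e.g. decoders_match(lambda z: hf_vae.decode(z).sample, lambda z: ms_vae.decode(z))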
@patrickvonplaten here is a Colab Notebook that shows the minor changes needed to make it work with diffusers: https://colab.research.google.com/gist/sayakpaul/f1b55ebcd2c850fcdeda351f3a4599e8/scratchpad.ipynb
@patrickvonplaten I uploaded it here: https://huggingface.co/diffusers/ms-text-to-video-1.7b/tree/main/vae
Here's the Colab Notebook I used: https://colab.research.google.com/gist/sayakpaul/930d6f582e4c5e381db1b392b479141b/scratchpad.ipynb
One weird thing: after converting the model checkpoints from ModelScope to Diffusers, I am seeing a reduction in size. The original VAE checkpoint (https://huggingface.co/damo-vilab/modelscope-damo-text-to-video-synthesis/blob/main/VQGAN_autoencoder.pth) is ~5 GB, while ours (https://huggingface.co/diffusers/ms-text-to-video-1.7b/blob/main/vae/diffusion_pytorch_model.bin) is 335 MB.
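A rough sketch of how one could sanity-check that discrepancy (assumptions: both checkpoints are downloaded locally, and the original .pth may bundle extra training-only tensors or a different precision, which would account for part of the gap; comparing tensor counts is more informative than comparing file sizes):
import torch
from diffusers import AutoencoderKL

# Original ModelScope checkpoint: count every tensor it stores.
original = torch.load("VQGAN_autoencoder.pth", map_location="cpu")
state_dict = original.get("state_dict", original) if isinstance(original, dict) else original
num_original = sum(v.numel() for v in state_dict.values() if torch.is_tensor(v))

# Converted diffusers VAE: count only the parameters that were kept after conversion.
vae = AutoencoderKL.from_pretrained("diffusers/ms-text-to-video-1.7b", subfolder="vae")
num_converted = sum(p.numel() for p in vae.parameters())

print(f"original tensors: {num_original:,} | converted params: {num_converted:,}")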
Pipeline works. Seems to work also with other schedulers and fp16. Can run with just 7 GB of memory using Torch 2.0:
#!/usr/bin/env python3
import cv2
import tempfile
from huggingface_hub import HfApi
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch
api = HfApi()
def write_video(video):
    output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    h, w, c = video[0].shape
    video_writer = cv2.VideoWriter(
        output_video_path, fourcc, fps=8, frameSize=(w, h))
    for i in range(len(video)):
        img = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
        video_writer.write(img)
    video_writer.release()  # flush frames to disk before returning the path
    return output_video_path
pipe = DiffusionPipeline.from_pretrained("diffusers/ms-text-to-video-sd", variant="fp16", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
video = pipe("Spiderman is surfing", num_inference_steps=25).frames
video_path = write_video(video)
api.upload_file(
    path_or_fileobj=video_path,
    path_in_repo="video.mp4",
    repo_id="patrickvonplaten/videos",
    repo_type="dataset",
)
print("https://huggingface.co/datasets/patrickvonplaten/videos/blob/main/video.mp4")
Can generate up to 8 seconds on a V100 thanks to VAE slicing:
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()
video = pipe("Darth Vader surfing a wave", num_frames=64, num_inference_steps=25).frames
@@ -0,0 +1,667 @@
# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
Add Alibaba citation
Probably not a big deal, but the pipeline code is arguably mostly HF :)
Yeah true, removed it there since there is really no code from Alibaba except the tensor2vid function, which is tiny and where we left a link and a comment.
@@ -0,0 +1,492 @@
# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
Add Alibaba copyright
src/diffusers/models/resnet.py
@@ -1,3 +1,18 @@
# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
Add Alibaba copyright
class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
Add comment that it's copied code
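For readers unfamiliar with the layer under discussion, a simplified, hand-written sketch of the idea behind such a temporal convolution block (not the exact class added in this PR): a 3D convolution over the frame axis only, wrapped in a residual connection.
import torch
import torch.nn as nn

class TinyTemporalConv(nn.Module):
    """Simplified temporal mixing block: convolve across frames, leave H/W untouched."""

    def __init__(self, channels: int):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)  # assumes channels is divisible by 32
        self.act = nn.SiLU()
        # kernel (3, 1, 1): 3 taps along the frame axis, 1x1 spatially
        self.conv = nn.Conv3d(channels, channels, kernel_size=(3, 1, 1), padding=(1, 0, 0))

    def forward(self, hidden_states: torch.Tensor, num_frames: int) -> torch.Tensor:
        # hidden_states: (batch * frames, channels, height, width), as in the 2D UNet blocks
        bf, c, h, w = hidden_states.shape
        x = hidden_states.reshape(bf // num_frames, num_frames, c, h, w).permute(0, 2, 1, 3, 4)
        x = self.conv(self.act(self.norm(x)))
        x = x.permute(0, 2, 1, 3, 4).reshape(bf, c, h, w)
        return hidden_states + x  # residual, so the block can start close to identity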
Amazing work! Just pointed out a few minor questions, but this is totally good to go imo.
@unittest.skipIf(
    torch_device != "cuda" or not is_xformers_available(),
    reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
This is always going to be skipped in our current CI, I think (it uses PyTorch 2 unless I'm mistaken).
Yeah we should maybe clean this up soon
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_pix2pix_negative_prompt(self):
Does pix2pix work? 😮
Yeah not sure - removed this one 😅
Should actually work. Just needed a renaming.
Can we pop it back in?
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
return_dict (`bool`, *optional*, defaults to `True`):
output_type is missing in the docstring.
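As context for the quoted docstring, a hedged sketch of how precomputed embeddings can be passed to the pipeline for prompt weighting. It assumes `pipe` is an already-loaded TextToVideoSDPipeline; the helper name `embed` is made up for illustration:
import torch

def embed(pipe, text):
    # Tokenize and run the pipeline's own text encoder, mirroring what the pipeline does internally.
    tokens = pipe.tokenizer(
        text,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        return pipe.text_encoder(tokens.input_ids.to(pipe.device))[0]

prompt_embeds = embed(pipe, "an astronaut riding a horse")
negative_prompt_embeds = embed(pipe, "low quality, blurry")
frames = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds).frames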
Co-authored-by: Pedro Cuenca <[email protected]>
>>> pipe = TextToVideoSDPipeline.from_pretrained(
...     "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
... )
>>> pipe.enable_model_cpu_offload()
We had previously talked about standardizing example snippets on enabling all optimizations, including UniPCMultistepScheduler, xformers, and 20 steps. Is that worth adding here?
Yeah, good point. I'm somewhat assuming people use torch 2.0 now, so there's no need for xformers anymore. UniPC doesn't work well with the model, but DPM works well. We should maybe add it there.
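A sketch of what the updated docstring example could look like with DPM-Solver and model CPU offload (assuming the released checkpoint id from the snippet above and the `export_to_video` utility added in this PR):
import torch
from diffusers import TextToVideoSDPipeline, DPMSolverMultistepScheduler
from diffusers.utils import export_to_video

pipe = TextToVideoSDPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
)
# Swap in DPM-Solver so 25 steps are enough, and offload submodules to save GPU memory.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

video_frames = pipe("Spiderman is surfing", num_inference_steps=25).frames
video_path = export_to_video(video_frames)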
def to_np(tensor):
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().cpu().numpy()

    return tensor
What's the reason we need to convert to numpy arrays here?
Text to video returns PyTorch tensors
lgtm!
Tests passing locally, merging now
Awesome! Thank you for all the great work! Dreambooth is the next step, I guess
* [MS Text To Video} Add first text to video * upload * make first model example * match unet3d params * make sure weights are correcctly converted * improve * forward pass works, but diff result * make forward work * fix more * finish * refactor video output class. * feat: add support for a video export utility. * fix: opencv availability check. * run make fix-copies. * add: docs for the model components. * add: standalone pipeline doc. * edit docstring of the pipeline. * add: right path to TransformerTempModel * add: first set of tests. * complete fast tests for text to video. * fix bug * up * three fast tests failing. * add: note on slow tests * make work with all schedulers * apply styling. * add slow tests * change file name * update * more correction * more fixes * finish * up * Apply suggestions from code review * up * finish * make copies * fix pipeline tests * fix more tests * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * apply suggestions * up * revert --------- Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
This PR adds the text-to-video model from ModelScope: https://modelscope.cn/models/damo/text-to-video-synthesis/summary
Also see: https://www.reddit.com/r/StableDiffusion/comments/11vbyei/first_open_source_text_to_video_17_billion/
The model consists of three components: a text encoder, a 3D denoising UNet, and a VAE.
Simple command to run the model:
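A minimal sketch of such a command (assuming the released checkpoint id `damo-vilab/text-to-video-ms-1.7b` and the `export_to_video` utility this PR adds):
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

pipe = DiffusionPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
video_frames = pipe("Spiderman is surfing", num_inference_steps=25).frames
print(export_to_video(video_frames))  # path to the written .mp4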
To reproduce results compared to the original model: