Skip to content

Text2video zero refinements #3070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from dataclasses import dataclass
from typing import Callable, List, Optional, Union

Expand Down Expand Up @@ -56,8 +57,8 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
is_cross_attention = encoder_hidden_states is not None
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
Expand Down Expand Up @@ -285,7 +286,8 @@ def backward_loop(
latents: latents of backward process output at time timesteps[-1]
"""
do_classifier_free_guidance = guidance_scale > 1.0
with self.progress_bar(total=len(timesteps)) as progress_bar:
num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order
with self.progress_bar(total=num_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
Expand Down Expand Up @@ -466,6 +468,7 @@ def __call__(
extra_step_kwargs=extra_step_kwargs,
num_warmup_steps=num_warmup_steps,
)
scheduler_copy = copy.deepcopy(self.scheduler)

# Perform the second backward process up to time T_0
x_1_t0 = self.backward_loop(
Expand All @@ -476,7 +479,7 @@ def __call__(
callback=callback,
callback_steps=callback_steps,
extra_step_kwargs=extra_step_kwargs,
num_warmup_steps=num_warmup_steps,
num_warmup_steps=0,
)

# Propagate first frame latents at time T_0 to remaining frames
Expand All @@ -503,7 +506,7 @@ def __call__(
b, l, d = prompt_embeds.size()
prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d)

self.scheduler.set_timesteps(num_inference_steps, device=device)
self.scheduler = scheduler_copy
x_1k_0 = self.backward_loop(
timesteps=timesteps[-t1 - 1 :],
prompt_embeds=prompt_embeds,
Expand All @@ -512,7 +515,7 @@ def __call__(
callback=callback,
callback_steps=callback_steps,
extra_step_kwargs=extra_step_kwargs,
num_warmup_steps=num_warmup_steps,
num_warmup_steps=0,
)
latents = x_1k_0

Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
load_hf_numpy,
load_image,
load_numpy,
load_pt,
nightly,
parse_flag_from_env,
print_tensor_test,
Expand Down
6 changes: 3 additions & 3 deletions tests/pipelines/text_to_video/test_text_to_video_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch

from diffusers import DDIMScheduler, TextToVideoZeroPipeline
from diffusers.utils import require_torch_gpu, slow
from diffusers.utils import load_pt, require_torch_gpu, slow

from ...test_pipelines_common import assert_mean_pixel_difference

Expand All @@ -35,8 +35,8 @@ def test_full_model(self):
prompt = "A bear is playing a guitar on Times Square"
result = pipe(prompt=prompt, generator=generator).images

expected_result = torch.load(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/tree/main/text-to-video/A bear is playing a guitar on Times Square.pt"
expected_result = load_pt(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/A bear is playing a guitar on Times Square.pt"
)

assert_mean_pixel_difference(result, expected_result)