diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
index 6cf4b8544b01..41687c7e053c 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
@@ -1,3 +1,4 @@
+import copy
 from dataclasses import dataclass
 from typing import Callable, List, Optional, Union
 
@@ -56,8 +57,8 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
         is_cross_attention = encoder_hidden_states is not None
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -285,7 +286,8 @@ def backward_loop(
             latents: latents of backward process output at time timesteps[-1]
         """
         do_classifier_free_guidance = guidance_scale > 1.0
-        with self.progress_bar(total=len(timesteps)) as progress_bar:
+        num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order
+        with self.progress_bar(total=num_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -466,6 +468,7 @@ def __call__(
             extra_step_kwargs=extra_step_kwargs,
             num_warmup_steps=num_warmup_steps,
         )
+        scheduler_copy = copy.deepcopy(self.scheduler)
 
         # Perform the second backward process up to time T_0
         x_1_t0 = self.backward_loop(
@@ -476,7 +479,7 @@ def __call__(
             callback=callback,
             callback_steps=callback_steps,
             extra_step_kwargs=extra_step_kwargs,
-            num_warmup_steps=num_warmup_steps,
+            num_warmup_steps=0,
         )
 
         # Propagate first frame latents at time T_0 to remaining frames
@@ -503,7 +506,7 @@ def __call__(
         b, l, d = prompt_embeds.size()
         prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d)
 
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        self.scheduler = scheduler_copy
         x_1k_0 = self.backward_loop(
             timesteps=timesteps[-t1 - 1 :],
             prompt_embeds=prompt_embeds,
@@ -512,7 +515,7 @@ def __call__(
             callback=callback,
             callback_steps=callback_steps,
             extra_step_kwargs=extra_step_kwargs,
-            num_warmup_steps=num_warmup_steps,
+            num_warmup_steps=0,
         )
 
         latents = x_1k_0
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index bb159d9db375..c717d722f84c 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -86,6 +86,7 @@
     load_hf_numpy,
     load_image,
     load_numpy,
+    load_pt,
     nightly,
     parse_flag_from_env,
     print_tensor_test,
diff --git a/tests/pipelines/text_to_video/test_text_to_video_zero.py b/tests/pipelines/text_to_video/test_text_to_video_zero.py
index e6a726bf13c5..45bb93fbd9c6 100644
--- a/tests/pipelines/text_to_video/test_text_to_video_zero.py
+++ b/tests/pipelines/text_to_video/test_text_to_video_zero.py
@@ -18,7 +18,7 @@
 import torch
 
 from diffusers import DDIMScheduler, TextToVideoZeroPipeline
-from diffusers.utils import require_torch_gpu, slow
+from diffusers.utils import load_pt, require_torch_gpu, slow
 
 from ...test_pipelines_common import assert_mean_pixel_difference
 
@@ -35,8 +35,8 @@ def test_full_model(self):
         prompt = "A bear is playing a guitar on Times Square"
         result = pipe(prompt=prompt, generator=generator).images
 
-        expected_result = torch.load(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/tree/main/text-to-video/A bear is playing a guitar on Times Square.pt"
+        expected_result = load_pt(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/A bear is playing a guitar on Times Square.pt"
        )
 
         assert_mean_pixel_difference(result, expected_result)