diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 0fb4637dab7f..8d5054ab8325 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -67,10 +67,7 @@ """ -def tensor2vid(video: torch.Tensor, processor, output_type="np"): - # Based on: - # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - +def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): batch_size, channels, num_frames, height, width = video.shape outputs = [] for batch_idx in range(batch_size): @@ -79,6 +76,15 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): outputs.append(batch_output) + if output_type == "np": + outputs = np.stack(outputs) + + elif output_type == "pt": + outputs = torch.stack(outputs) + + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") + return outputs @@ -805,11 +811,7 @@ def _retrieve_video_frames(self, latents, output_type, return_dict): return AnimateDiffPipelineOutput(frames=latents) video_tensor = self.decode_latents(latents) - - if output_type == "pt": - video = video_tensor - else: - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index 56f72691303d..9bfced85955b 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -40,10 +40,8 @@ def _append_dims(x, target_dims): return x[(...,) + (None,) * dims_to_append] -def tensor2vid(video: torch.Tensor, processor, output_type="np"): - # Based on: - # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - +# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid +def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): batch_size, channels, num_frames, height, width = video.shape outputs = [] for batch_idx in range(batch_size): @@ -53,7 +51,13 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): outputs.append(batch_output) if output_type == "np": - return np.stack(outputs) + outputs = np.stack(outputs) + + elif output_type == "pt": + outputs = torch.stack(outputs) + + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") return outputs diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index ab5286a5e5b4..6e5db85c9e66 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -19,6 +19,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -58,22 +59,26 @@ """ -def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: - # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - # reshape to ncfhw - mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) - std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) - # unnormalize back to [0,1] - video = video.mul_(std).add_(mean) - video.clamp_(0, 1) - # prepare the final outputs - i, c, f, h, w = video.shape - images = video.permute(2, 3, 0, 4, 1).reshape( - f, h, i * w, c - ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) - images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) - images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c - return images +# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid +def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + + elif output_type == "pt": + outputs = torch.stack(outputs) + + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") + + return outputs class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): @@ -122,6 +127,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -717,11 +723,7 @@ def __call__( return TextToVideoSDPipelineOutput(frames=latents) video_tensor = self.decode_latents(latents) - - if output_type == "pt": - video = video_tensor - else: - video = tensor2vid(video_tensor) + video = tensor2vid(video_tensor, self.image_processor, output_type) # Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index b19ccee660e2..c781e490caae 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -20,6 +20,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -93,22 +94,26 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: - # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - # reshape to ncfhw - mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) - std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) - # unnormalize back to [0,1] - video = video.mul_(std).add_(mean) - video.clamp_(0, 1) - # prepare the final outputs - i, c, f, h, w = video.shape - images = video.permute(2, 3, 0, 4, 1).reshape( - f, h, i * w, c - ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) - images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) - images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c - return images +# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid +def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + + elif output_type == "pt": + outputs = torch.stack(outputs) + + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") + + return outputs def preprocess_video(video): @@ -198,6 +203,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -812,12 +818,11 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") - video_tensor = self.decode_latents(latents) + if output_type == "latent": + return TextToVideoSDPipelineOutput(frames=latents) - if output_type == "pt": - video = video_tensor - else: - video = tensor2vid(video_tensor) + video_tensor = self.decode_latents(latents) + video = tensor2vid(video_tensor, self.image_processor, output_type) # Offload all models self.maybe_free_model_hooks() diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 44cb730a9501..80a8fd19f5a0 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -262,7 +262,7 @@ def test_free_init(self): sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum() max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max() self.assertGreater( - sum_enabled, 1e2, "Enabling of FreeInit should lead to results different from the default pipeline results" + sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results" ) self.assertLess( max_diff_disabled, diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py index e9f435239c92..2f48dc5c318a 100644 --- a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py +++ b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py @@ -29,6 +29,7 @@ from diffusers.utils.testing_utils import ( enable_full_determinism, load_numpy, + numpy_cosine_similarity_distance, require_torch_gpu, skip_mps, slow, @@ -141,10 +142,11 @@ def test_text_to_video_default_case(self): inputs = self.get_dummy_inputs(device) inputs["output_type"] = "np" frames = sd_pipe(**inputs).frames - image_slice = frames[0][-3:, -3:, -1] - assert frames[0].shape == (32, 32, 3) - expected_slice = np.array([192.0, 44.0, 157.0, 140.0, 108.0, 104.0, 123.0, 144.0, 129.0]) + image_slice = frames[0][0][-3:, -3:, -1] + + assert frames[0][0].shape == (32, 32, 3) + expected_slice = np.array([0.7537, 0.1752, 0.6157, 0.5508, 0.4240, 0.4110, 0.4838, 0.5648, 0.5094]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -183,7 +185,7 @@ def test_progress_bar(self): class TextToVideoSDPipelineSlowTests(unittest.TestCase): def test_two_step_model(self): expected_video = load_numpy( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video_2step.npy" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/video_2step.npy" ) pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b") @@ -192,10 +194,8 @@ def test_two_step_model(self): prompt = "Spiderman is surfing" generator = torch.Generator(device="cpu").manual_seed(0) - video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames - video = video_frames.cpu().numpy() - - assert np.abs(expected_video - video).mean() < 5e-2 + video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames + assert numpy_cosine_similarity_distance(expected_video.flatten(), video_frames.flatten()) < 1e-4 def test_two_step_model_with_freeu(self): expected_video = [] @@ -207,10 +207,9 @@ def test_two_step_model_with_freeu(self): generator = torch.Generator(device="cpu").manual_seed(0) pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4) - video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames - video = video_frames.cpu().numpy() - video = video[0, 0, -3:, -3:, -1].flatten() + video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames + video = video_frames[0, 0, -3:, -3:, -1].flatten() - expected_video = [-0.3102, -0.2477, -0.1772, -0.648, -0.6176, -0.5484, -0.0217, -0.056, -0.0177] + expected_video = [0.3643, 0.3455, 0.3831, 0.3923, 0.2978, 0.3247, 0.3278, 0.3201, 0.3475] assert np.abs(expected_video - video).mean() < 5e-2 diff --git a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py index 1785eb967f16..07d48eba5574 100644 --- a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py +++ b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py @@ -157,10 +157,10 @@ def test_text_to_video_default_case(self): inputs = self.get_dummy_inputs(device) inputs["output_type"] = "np" frames = sd_pipe(**inputs).frames - image_slice = frames[0][-3:, -3:, -1] + image_slice = frames[0][0][-3:, -3:, -1] - assert frames[0].shape == (32, 32, 3) - expected_slice = np.array([162.0, 136.0, 132.0, 140.0, 139.0, 137.0, 169.0, 134.0, 132.0]) + assert frames[0][0].shape == (32, 32, 3) + expected_slice = np.array([0.6391, 0.5350, 0.5202, 0.5521, 0.5453, 0.5393, 0.6652, 0.5270, 0.5185]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -214,9 +214,11 @@ def test_two_step_model(self): prompt = "Spiderman is surfing" - video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="pt").frames - - expected_array = np.array([-0.9770508, -0.8027344, -0.62646484, -0.8334961, -0.7573242]) - output_array = video_frames.cpu().numpy()[0, 0, 0, 0, -5:] + generator = torch.Generator(device="cpu").manual_seed(0) + video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="np").frames - assert numpy_cosine_similarity_distance(expected_array, output_array) < 1e-2 + expected_array = np.array( + [0.17114258, 0.13720703, 0.08886719, 0.14819336, 0.1730957, 0.24584961, 0.22021484, 0.35180664, 0.2607422] + ) + output_array = video_frames[0, 0, :3, :3, 0].flatten() + assert numpy_cosine_similarity_distance(expected_array, output_array) < 1e-3