From f3420ed2d951bb76e7d219590ef1652172ce955b Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 17 Jan 2024 12:39:07 +0000 Subject: [PATCH 1/9] update --- .../animatediff/pipeline_animatediff.py | 6 ++-- .../pipeline_stable_video_diffusion.py | 4 +-- .../pipeline_text_to_video_synth.py | 34 +++++++++---------- .../test_text_to_video.py | 5 +-- 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 0fb4637dab7f..03e0c49241a5 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -68,9 +68,6 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): - # Based on: - # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - batch_size, channels, num_frames, height, width = video.shape outputs = [] for batch_idx in range(batch_size): @@ -79,6 +76,9 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): outputs.append(batch_output) + if output_type == "np": + outputs = np.stack(outputs) + return outputs diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index 56f72691303d..07dddf457d62 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -40,10 +40,8 @@ def _append_dims(x, target_dims): return x[(...,) + (None,) * dims_to_append] +# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid def tensor2vid(video: torch.Tensor, processor, output_type="np"): - # Based on: - # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - batch_size, channels, num_frames, height, width = video.shape outputs = [] for batch_idx in range(batch_size): diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index ab5286a5e5b4..62f83c233286 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -19,6 +19,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -58,22 +59,20 @@ """ -def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: - # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - # reshape to ncfhw - mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) - std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) - # unnormalize back to [0,1] - video = video.mul_(std).add_(mean) - video.clamp_(0, 1) - # prepare the final outputs - i, c, f, h, w = video.shape - images = 
video.permute(2, 3, 0, 4, 1).reshape( - f, h, i * w, c - ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) - images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) - images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c - return images +# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid +def tensor2vid(video: torch.Tensor, processor, output_type="np"): + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + + return outputs class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): @@ -122,6 +121,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -721,7 +721,7 @@ def __call__( if output_type == "pt": video = video_tensor else: - video = tensor2vid(video_tensor) + video = tensor2vid(video_tensor, self.image_processor, output_type) # Offload all models self.maybe_free_model_hooks() diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py index e9f435239c92..4fa6bdb20116 100644 --- a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py +++ b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py @@ -141,9 +141,10 @@ def test_text_to_video_default_case(self): inputs = self.get_dummy_inputs(device) inputs["output_type"] = "np" frames = sd_pipe(**inputs).frames - image_slice = frames[0][-3:, -3:, -1] - assert frames[0].shape == (32, 32, 3) + image_slice = frames[0][0][-3:, -3:, -1] + + assert frames[0][0].shape == (32, 32, 3) expected_slice = np.array([192.0, 44.0, 157.0, 140.0, 108.0, 104.0, 123.0, 144.0, 129.0]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 From f2f58d1eaf7e2f524b050b673c963914d1385951 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 18 Jan 2024 07:40:08 +0000 Subject: [PATCH 2/9] update --- tests/pipelines/text_to_video_synthesis/test_text_to_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py index 4fa6bdb20116..7bd56033987d 100644 --- a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py +++ b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py @@ -145,7 +145,7 @@ def test_text_to_video_default_case(self): image_slice = frames[0][0][-3:, -3:, -1] assert frames[0][0].shape == (32, 32, 3) - expected_slice = np.array([192.0, 44.0, 157.0, 140.0, 108.0, 104.0, 123.0, 144.0, 129.0]) + expected_slice = np.array([0.7537, 0.1752, 0.6157, 0.5508, 0.4240, 0.4110, 0.4838, 0.5648, 0.5094]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 From e601a992b61ddd7443a63518df58ef8507aa0e4b Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 18 Jan 2024 08:21:08 +0000 Subject: [PATCH 3/9] update --- .../pipelines/animatediff/pipeline_animatediff.py | 9 ++++----- 
.../pipeline_stable_video_diffusion.py | 5 ++++- .../pipeline_text_to_video_synth.py | 3 +++ 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 03e0c49241a5..1986d4932e1f 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -79,6 +79,9 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): if output_type == "np": outputs = np.stack(outputs) + if output_type == "pt": + outputs = torch.stack(outputs) + return outputs @@ -805,11 +808,7 @@ def _retrieve_video_frames(self, latents, output_type, return_dict): return AnimateDiffPipelineOutput(frames=latents) video_tensor = self.decode_latents(latents) - - if output_type == "pt": - video = video_tensor - else: - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index 07dddf457d62..78b6a25a21b0 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -51,7 +51,10 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): outputs.append(batch_output) if output_type == "np": - return np.stack(outputs) + outputs = np.stack(outputs) + + if output_type == "pt": + outputs = torch.stack(outputs) return outputs diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 62f83c233286..d7c1322aed7e 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -72,6 +72,9 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): if output_type == "np": outputs = np.stack(outputs) + if output_type == "pt": + outputs = torch.cat(outputs, dim=0) + return outputs From bd1ac23d1d6badbbfbf91c6efc12a2f2daf662a9 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 18 Jan 2024 08:29:19 +0000 Subject: [PATCH 4/9] update --- .../text_to_video_synthesis/pipeline_text_to_video_synth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index d7c1322aed7e..b665d4c89c55 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -73,7 +73,7 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): outputs = np.stack(outputs) if output_type == "pt": - outputs = torch.cat(outputs, dim=0) + outputs = torch.stack(outputs) return outputs From 1d27c52ee0e84ee8f82d0e2af559326872fa8208 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 18 Jan 2024 09:05:50 +0000 Subject: [PATCH 5/9] update --- .../pipeline_text_to_video_synth.py | 6 +----- .../test_text_to_video.py | 16 +++++++--------- 2 files changed, 8 insertions(+), 14 
deletions(-) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index b665d4c89c55..680516326893 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -720,11 +720,7 @@ def __call__( return TextToVideoSDPipelineOutput(frames=latents) video_tensor = self.decode_latents(latents) - - if output_type == "pt": - video = video_tensor - else: - video = tensor2vid(video_tensor, self.image_processor, output_type) + video = tensor2vid(video_tensor, self.image_processor, output_type) # Offload all models self.maybe_free_model_hooks() diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py index 7bd56033987d..2f48dc5c318a 100644 --- a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py +++ b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py @@ -29,6 +29,7 @@ from diffusers.utils.testing_utils import ( enable_full_determinism, load_numpy, + numpy_cosine_similarity_distance, require_torch_gpu, skip_mps, slow, @@ -184,7 +185,7 @@ def test_progress_bar(self): class TextToVideoSDPipelineSlowTests(unittest.TestCase): def test_two_step_model(self): expected_video = load_numpy( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video_2step.npy" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/video_2step.npy" ) pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b") @@ -193,10 +194,8 @@ def test_two_step_model(self): prompt = "Spiderman is surfing" generator = torch.Generator(device="cpu").manual_seed(0) - video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames - video = video_frames.cpu().numpy() - - assert np.abs(expected_video - video).mean() < 5e-2 + video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames + assert numpy_cosine_similarity_distance(expected_video.flatten(), video_frames.flatten()) < 1e-4 def test_two_step_model_with_freeu(self): expected_video = [] @@ -208,10 +207,9 @@ def test_two_step_model_with_freeu(self): generator = torch.Generator(device="cpu").manual_seed(0) pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4) - video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames - video = video_frames.cpu().numpy() - video = video[0, 0, -3:, -3:, -1].flatten() + video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames + video = video_frames[0, 0, -3:, -3:, -1].flatten() - expected_video = [-0.3102, -0.2477, -0.1772, -0.648, -0.6176, -0.5484, -0.0217, -0.056, -0.0177] + expected_video = [0.3643, 0.3455, 0.3831, 0.3923, 0.2978, 0.3247, 0.3278, 0.3201, 0.3475] assert np.abs(expected_video - video).mean() < 5e-2 From b4616b0dbb2f009e4cd00e4d8ecf80c500c02bed Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 18 Jan 2024 10:24:54 +0000 Subject: [PATCH 6/9] update --- tests/pipelines/animatediff/test_animatediff.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 44cb730a9501..80a8fd19f5a0 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ 
b/tests/pipelines/animatediff/test_animatediff.py @@ -262,7 +262,7 @@ def test_free_init(self): sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum() max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max() self.assertGreater( - sum_enabled, 1e2, "Enabling of FreeInit should lead to results different from the default pipeline results" + sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results" ) self.assertLess( max_diff_disabled, From c765e861016a5f30921f44b13dbe2df38007851c Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 22 Jan 2024 08:22:20 +0000 Subject: [PATCH 7/9] update --- .../pipelines/animatediff/pipeline_animatediff.py | 7 +++++-- .../pipeline_stable_video_diffusion.py | 5 ++++- .../pipeline_text_to_video_synth.py | 5 ++++- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 1986d4932e1f..8d5054ab8325 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -67,7 +67,7 @@ """ -def tensor2vid(video: torch.Tensor, processor, output_type="np"): +def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): batch_size, channels, num_frames, height, width = video.shape outputs = [] for batch_idx in range(batch_size): @@ -79,9 +79,12 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): if output_type == "np": outputs = np.stack(outputs) - if output_type == "pt": + elif output_type == "pt": outputs = torch.stack(outputs) + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") + return outputs diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index 78b6a25a21b0..cf3bb0cc1f20 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -53,9 +53,12 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): if output_type == "np": outputs = np.stack(outputs) - if output_type == "pt": + elif output_type == "pt": outputs = torch.stack(outputs) + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") + return outputs diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 680516326893..5c790cccc0ce 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -72,9 +72,12 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): if output_type == "np": outputs = np.stack(outputs) - if output_type == "pt": + elif output_type == "pt": outputs = torch.stack(outputs) + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. 
Please choose one of ['np', 'pt', 'pil]") + return outputs From 623d44818a3651f11b4b0b6816ceef3f48a75302 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 22 Jan 2024 11:41:50 +0000 Subject: [PATCH 8/9] clean up --- .../pipeline_stable_video_diffusion.py | 2 +- .../pipeline_text_to_video_synth.py | 2 +- .../pipeline_text_to_video_synth_img2img.py | 47 ++++++++++--------- .../test_video_to_video.py | 17 +++---- 4 files changed, 37 insertions(+), 31 deletions(-) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index cf3bb0cc1f20..9bfced85955b 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -41,7 +41,7 @@ def _append_dims(x, target_dims): # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: torch.Tensor, processor, output_type="np"): +def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): batch_size, channels, num_frames, height, width = video.shape outputs = [] for batch_idx in range(batch_size): diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 5c790cccc0ce..6e5db85c9e66 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -60,7 +60,7 @@ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: torch.Tensor, processor, output_type="np"): +def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): batch_size, channels, num_frames, height, width = video.shape outputs = [] for batch_idx in range(batch_size): diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index b19ccee660e2..c781e490caae 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -20,6 +20,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -93,22 +94,26 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: - # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - # reshape to ncfhw - mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) - std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) - # unnormalize back to [0,1] - video = video.mul_(std).add_(mean) - video.clamp_(0, 1) - # prepare the final outputs - i, c, f, h, w = video.shape - images = video.permute(2, 3, 0, 4, 1).reshape( - f, h, i * w, c - 
) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) - images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) - images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c - return images +# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid +def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + + elif output_type == "pt": + outputs = torch.stack(outputs) + + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") + + return outputs def preprocess_video(video): @@ -198,6 +203,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -812,12 +818,11 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") - video_tensor = self.decode_latents(latents) + if output_type == "latent": + return TextToVideoSDPipelineOutput(frames=latents) - if output_type == "pt": - video = video_tensor - else: - video = tensor2vid(video_tensor) + video_tensor = self.decode_latents(latents) + video = tensor2vid(video_tensor, self.image_processor, output_type) # Offload all models self.maybe_free_model_hooks() diff --git a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py index 1785eb967f16..f01bb481ac73 100644 --- a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py +++ b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py @@ -33,6 +33,7 @@ is_flaky, nightly, numpy_cosine_similarity_distance, + print_tensor_test, skip_mps, torch_device, ) @@ -157,10 +158,10 @@ def test_text_to_video_default_case(self): inputs = self.get_dummy_inputs(device) inputs["output_type"] = "np" frames = sd_pipe(**inputs).frames - image_slice = frames[0][-3:, -3:, -1] + image_slice = frames[0][0][-3:, -3:, -1] - assert frames[0].shape == (32, 32, 3) - expected_slice = np.array([162.0, 136.0, 132.0, 140.0, 139.0, 137.0, 169.0, 134.0, 132.0]) + assert frames[0][0].shape == (32, 32, 3) + expected_slice = np.array([0.6391, 0.5350, 0.5202, 0.5521, 0.5453, 0.5393, 0.6652, 0.5270, 0.5185]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -214,9 +215,9 @@ def test_two_step_model(self): prompt = "Spiderman is surfing" - video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="pt").frames - - expected_array = np.array([-0.9770508, -0.8027344, -0.62646484, -0.8334961, -0.7573242]) - output_array = video_frames.cpu().numpy()[0, 0, 0, 0, -5:] + generator = torch.Generator(device="cpu").manual_seed(0) + video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="np").frames - assert numpy_cosine_similarity_distance(expected_array, output_array) < 1e-2 + expected_array = np.array([0.17114258, 
0.13720703, 0.08886719, 0.14819336, 0.1730957, 0.24584961, 0.22021484, 0.35180664, 0.2607422]) + output_array = video_frames[0, 0, :3, :3, 0].flatten() + assert numpy_cosine_similarity_distance(expected_array, output_array) < 1e-3 From 341f48da0d43ec6d75d1f29658e3f278c38dbcfd Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 22 Jan 2024 11:46:36 +0000 Subject: [PATCH 9/9] clean up --- .../pipelines/text_to_video_synthesis/test_video_to_video.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py index f01bb481ac73..07d48eba5574 100644 --- a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py +++ b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py @@ -33,7 +33,6 @@ is_flaky, nightly, numpy_cosine_similarity_distance, - print_tensor_test, skip_mps, torch_device, ) @@ -218,6 +217,8 @@ def test_two_step_model(self): generator = torch.Generator(device="cpu").manual_seed(0) video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="np").frames - expected_array = np.array([0.17114258, 0.13720703, 0.08886719, 0.14819336, 0.1730957, 0.24584961, 0.22021484, 0.35180664, 0.2607422]) + expected_array = np.array( + [0.17114258, 0.13720703, 0.08886719, 0.14819336, 0.1730957, 0.24584961, 0.22021484, 0.35180664, 0.2607422] + ) output_array = video_frames[0, 0, :3, :3, 0].flatten() assert numpy_cosine_similarity_distance(expected_array, output_array) < 1e-3
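
For reference, a minimal sketch (not part of the patch series) of how the reworked frame postprocessing behaves from the caller's side after this change. The pipeline class, model id, call arguments, and output layout are taken from the updated tests above; the exact resolution and frame count depend on the checkpoint and call arguments.

    # Sketch only: illustrates the new output layout, based on the calls in the updated tests.
    import torch
    from diffusers import TextToVideoSDPipeline

    pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
    pipe = pipe.to("cuda")

    generator = torch.Generator(device="cpu").manual_seed(0)
    # With output_type="np", .frames is now a stacked array of shape
    # (batch_size, num_frames, height, width, 3) with float values in [0, 1],
    # instead of the previous list of per-frame uint8 arrays with the batch
    # folded into the width dimension.
    frames = pipe(
        "Spiderman is surfing",
        generator=generator,
        num_inference_steps=2,
        output_type="np",
    ).frames

    single_frame = frames[0][0]  # one RGB frame of the first video: (height, width, 3)

The same `tensor2vid(video, processor, output_type)` helper is now shared (via "Copied from") across the AnimateDiff, Stable Video Diffusion, text-to-video, and video-to-video pipelines, so "np", "pt", and "pil" outputs are handled uniformly and any other `output_type` raises a ValueError.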