Merged
20 changes: 11 additions & 9 deletions src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -67,10 +67,7 @@
"""


def tensor2vid(video: torch.Tensor, processor, output_type="np"):
# Based on:
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78

def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
@@ -79,6 +76,15 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):

outputs.append(batch_output)

if output_type == "np":
outputs = np.stack(outputs)

elif output_type == "pt":
outputs = torch.stack(outputs)

elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

return outputs


@@ -805,11 +811,7 @@ def _retrieve_video_frames(self, latents, output_type, return_dict):
return AnimateDiffPipelineOutput(frames=latents)

video_tensor = self.decode_latents(latents)

if output_type == "pt":
video = video_tensor
else:
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

if not return_dict:
return (video,)
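For reference, a minimal sketch of how the unified tensor2vid helper can be exercised on its own, assuming it remains importable from the animatediff pipeline module shown above; the dummy tensor shape and scale factor are illustrative only:

import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.animatediff.pipeline_animatediff import tensor2vid

# Dummy decoded video in (batch, channels, frames, height, width), values roughly in [-1, 1]
video = torch.rand(1, 3, 4, 64, 64) * 2 - 1
processor = VaeImageProcessor(vae_scale_factor=8)

np_out = tensor2vid(video, processor, output_type="np")    # np.ndarray stacked over batch
pt_out = tensor2vid(video, processor, output_type="pt")    # torch.Tensor stacked over batch
pil_out = tensor2vid(video, processor, output_type="pil")  # list (per batch item) of lists of PIL images
print(np_out.shape, pt_out.shape, len(pil_out[0]))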
(next changed file)
@@ -40,10 +40,8 @@ def _append_dims(x, target_dims):
return x[(...,) + (None,) * dims_to_append]


def tensor2vid(video: torch.Tensor, processor, output_type="np"):
# Based on:
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78

# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
Contributor: Nice!

def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
@@ -53,7 +51,13 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
outputs.append(batch_output)

if output_type == "np":
return np.stack(outputs)
outputs = np.stack(outputs)

elif output_type == "pt":
outputs = torch.stack(outputs)

elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

return outputs

(next changed file)
@@ -19,6 +19,7 @@
import torch
from transformers import CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
@@ -58,22 +59,26 @@
"""


def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
# This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
# reshape to ncfhw
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
# unnormalize back to [0,1]
video = video.mul_(std).add_(mean)
video.clamp_(0, 1)
# prepare the final outputs
i, c, f, h, w = video.shape
images = video.permute(2, 3, 0, 4, 1).reshape(
f, h, i * w, c
) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames)
images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
return images
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)

outputs.append(batch_output)

if output_type == "np":
outputs = np.stack(outputs)

elif output_type == "pt":
outputs = torch.stack(outputs)

elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

return outputs


class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
@@ -122,6 +127,7 @@ def __init__(
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
@@ -717,11 +723,7 @@ def __call__(
return TextToVideoSDPipelineOutput(frames=latents)

video_tensor = self.decode_latents(latents)

if output_type == "pt":
video = video_tensor
else:
video = tensor2vid(video_tensor)
video = tensor2vid(video_tensor, self.image_processor, output_type)

# Offload all models
self.maybe_free_model_hooks()
(next changed file)
@@ -20,6 +20,7 @@
import torch
from transformers import CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
@@ -93,22 +94,26 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")


def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
# This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
# reshape to ncfhw
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
# unnormalize back to [0,1]
video = video.mul_(std).add_(mean)
video.clamp_(0, 1)
# prepare the final outputs
i, c, f, h, w = video.shape
images = video.permute(2, 3, 0, 4, 1).reshape(
f, h, i * w, c
) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames)
images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
return images
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)

outputs.append(batch_output)

if output_type == "np":
outputs = np.stack(outputs)

elif output_type == "pt":
outputs = torch.stack(outputs)

elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

return outputs


def preprocess_video(video):
@@ -198,6 +203,7 @@ def __init__(
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
@@ -812,12 +818,11 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")

video_tensor = self.decode_latents(latents)
if output_type == "latent":
return TextToVideoSDPipelineOutput(frames=latents)

if output_type == "pt":
video = video_tensor
else:
video = tensor2vid(video_tensor)
video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type)

# Offload all models
self.maybe_free_model_hooks()
2 changes: 1 addition & 1 deletion tests/pipelines/animatediff/test_animatediff.py
@@ -262,7 +262,7 @@ def test_free_init(self):
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
self.assertGreater(
sum_enabled, 1e2, "Enabling of FreeInit should lead to results different from the default pipeline results"
sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results"
)
self.assertLess(
max_diff_disabled,
23 changes: 11 additions & 12 deletions tests/pipelines/text_to_video_synthesis/test_text_to_video.py
@@ -29,6 +29,7 @@
from diffusers.utils.testing_utils import (
enable_full_determinism,
load_numpy,
numpy_cosine_similarity_distance,
require_torch_gpu,
skip_mps,
slow,
@@ -141,10 +142,11 @@ def test_text_to_video_default_case(self):
inputs = self.get_dummy_inputs(device)
inputs["output_type"] = "np"
frames = sd_pipe(**inputs).frames
image_slice = frames[0][-3:, -3:, -1]

assert frames[0].shape == (32, 32, 3)
expected_slice = np.array([192.0, 44.0, 157.0, 140.0, 108.0, 104.0, 123.0, 144.0, 129.0])
image_slice = frames[0][0][-3:, -3:, -1]

assert frames[0][0].shape == (32, 32, 3)
expected_slice = np.array([0.7537, 0.1752, 0.6157, 0.5508, 0.4240, 0.4110, 0.4838, 0.5648, 0.5094])

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
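The extra index above follows from the new output layout: with output_type="np", tensor2vid now returns a float array stacked over batch and frames with values in [0, 1], rather than a list of uint8 frames, so a single frame is reached via frames[0][0]. A rough illustration with dummy values (shapes follow the "np" branch of tensor2vid above):

import numpy as np

# (batch, frames, height, width, channels), floats in [0, 1] — the layout the updated asserts expect
frames = np.random.rand(1, 4, 32, 32, 3).astype(np.float32)
first_frame = frames[0][0]                # a single (32, 32, 3) frame
image_slice = first_frame[-3:, -3:, -1]   # 3x3 patch of the last channel
assert first_frame.shape == (32, 32, 3)
assert image_slice.shape == (3, 3)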

Expand Down Expand Up @@ -183,7 +185,7 @@ def test_progress_bar(self):
class TextToVideoSDPipelineSlowTests(unittest.TestCase):
def test_two_step_model(self):
expected_video = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video_2step.npy"
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/video_2step.npy"
)

pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
@@ -192,10 +194,8 @@ def test_two_step_model(self):
prompt = "Spiderman is surfing"
generator = torch.Generator(device="cpu").manual_seed(0)

video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames
video = video_frames.cpu().numpy()

assert np.abs(expected_video - video).mean() < 5e-2
video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames
assert numpy_cosine_similarity_distance(expected_video.flatten(), video_frames.flatten()) < 1e-4

def test_two_step_model_with_freeu(self):
expected_video = []
@@ -207,10 +207,9 @@ def test_two_step_model_with_freeu(self):
generator = torch.Generator(device="cpu").manual_seed(0)

pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames
video = video_frames.cpu().numpy()
video = video[0, 0, -3:, -3:, -1].flatten()
video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames
video = video_frames[0, 0, -3:, -3:, -1].flatten()

expected_video = [-0.3102, -0.2477, -0.1772, -0.648, -0.6176, -0.5484, -0.0217, -0.056, -0.0177]
expected_video = [0.3643, 0.3455, 0.3831, 0.3923, 0.2978, 0.3247, 0.3278, 0.3201, 0.3475]

assert np.abs(expected_video - video).mean() < 5e-2
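Some of the slow tests above now compare outputs with numpy_cosine_similarity_distance instead of a mean absolute difference. A rough standalone equivalent of that comparison, for illustration only (the real helper lives in diffusers.utils.testing_utils and may differ in detail):

import numpy as np

def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity of the flattened arrays: small when the two outputs
    # point in nearly the same direction, regardless of small elementwise noise
    a = a.flatten().astype(np.float64)
    b = b.flatten().astype(np.float64)
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

expected = np.array([0.3643, 0.3455, 0.3831, 0.3923, 0.2978])
got = expected + 1e-4 * np.random.randn(5)
assert cosine_distance(expected, got) < 1e-4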
18 changes: 10 additions & 8 deletions tests/pipelines/text_to_video_synthesis/test_video_to_video.py
@@ -157,10 +157,10 @@ def test_text_to_video_default_case(self):
inputs = self.get_dummy_inputs(device)
inputs["output_type"] = "np"
frames = sd_pipe(**inputs).frames
image_slice = frames[0][-3:, -3:, -1]
image_slice = frames[0][0][-3:, -3:, -1]

assert frames[0].shape == (32, 32, 3)
expected_slice = np.array([162.0, 136.0, 132.0, 140.0, 139.0, 137.0, 169.0, 134.0, 132.0])
assert frames[0][0].shape == (32, 32, 3)
expected_slice = np.array([0.6391, 0.5350, 0.5202, 0.5521, 0.5453, 0.5393, 0.6652, 0.5270, 0.5185])

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@@ -214,9 +214,11 @@ def test_two_step_model(self):

prompt = "Spiderman is surfing"

video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="pt").frames

expected_array = np.array([-0.9770508, -0.8027344, -0.62646484, -0.8334961, -0.7573242])
output_array = video_frames.cpu().numpy()[0, 0, 0, 0, -5:]
generator = torch.Generator(device="cpu").manual_seed(0)
video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="np").frames

assert numpy_cosine_similarity_distance(expected_array, output_array) < 1e-2
expected_array = np.array(
[0.17114258, 0.13720703, 0.08886719, 0.14819336, 0.1730957, 0.24584961, 0.22021484, 0.35180664, 0.2607422]
)
output_array = video_frames[0, 0, :3, :3, 0].flatten()
assert numpy_cosine_similarity_distance(expected_array, output_array) < 1e-3