From d6912acd8fb0cb02f2fe0d16ff547d551185a0ed Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 19 Mar 2023 20:11:37 +0000 Subject: [PATCH 01/43] [MS Text To Video} Add first text to video --- .../pipeline_text_to_video_synth.py | 622 ++++++++++++++++++ 1 file changed, 622 insertions(+) create mode 100644 src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py new file mode 100644 index 000000000000..6caab78040a7 --- /dev/null +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -0,0 +1,622 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import TextToVideoMSPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import TextToVideoMSPipeline + + >>> pipe = TextToVideoMSPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +class TextToVideoMSPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoMSPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.TextToVideoMSPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.TextToVideoMSPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + else: + # 8. Post-processing + image = self.decode_latents(latents) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return TextToVideoMSPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From bf1c935509f5ef62f6232d209890c3d464ee9f22 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 19 Mar 2023 20:38:06 +0000 Subject: [PATCH 02/43] upload --- src/diffusers/models/unet_3d_condition.py | 599 ++++++++++++++++++++++ 1 file changed, 599 insertions(+) create mode 100644 src/diffusers/models/unet_3d_condition.py diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py new file mode 100644 index 000000000000..535741adeaeb --- /dev/null +++ b/src/diffusers/models/unet_3d_condition.py @@ -0,0 +1,599 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging +from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + UpBlock2D, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + r""" + UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the + mid block layer if `None`. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, it will skip the normalization and activation layers in post-processing + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to None): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, or `"projection"`. + num_class_embeds (`int`, *optional*, defaults to None): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, default to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + timestep_post_act (`str, *optional*, default to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, default to `None`): + The dimension of `cond_proj` layer in timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + time_embedding_type: str = "positional", + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + ): + super().__init__() + + self.sample_size = sample_size + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Pleaes make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet3DConditionOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + [`~models.unet_2d_condition.UNet3DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) From 8a29fe689b797a8e7bd79c2b07e596c3fadba394 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 20 Mar 2023 10:14:49 +0000 Subject: [PATCH 03/43] make first model example --- .../convert_ms_text_to_video_to_diffusers.py | 131 ++++ src/diffusers/__init__.py | 1 + src/diffusers/models/__init__.py | 1 + src/diffusers/models/transformer_temp.py | 313 ++++++++ src/diffusers/models/unet_3d_blocks.py | 730 ++++++++++++++++++ src/diffusers/models/unet_3d_condition.py | 154 ++-- 6 files changed, 1234 insertions(+), 96 deletions(-) create mode 100644 scripts/convert_ms_text_to_video_to_diffusers.py create mode 100644 src/diffusers/models/transformer_temp.py create mode 100644 src/diffusers/models/unet_3d_blocks.py diff --git a/scripts/convert_ms_text_to_video_to_diffusers.py b/scripts/convert_ms_text_to_video_to_diffusers.py new file mode 100644 index 000000000000..59c7522bf838 --- /dev/null +++ b/scripts/convert_ms_text_to_video_to_diffusers.py @@ -0,0 +1,131 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Conversion script for the LDM checkpoints. """ + +import argparse +import torch +from diffusers import UNet3DConditionModel + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml + parser.add_argument( + "--original_config_file", + default=None, + type=str, + help="The YAML config file corresponding to the original architecture.", + ) + parser.add_argument( + "--num_in_channels", + default=None, + type=int, + help="The number of input channels. If `None` number of input channels will be automatically inferred.", + ) + parser.add_argument( + "--scheduler_type", + default="pndm", + type=str, + help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']", + ) + parser.add_argument( + "--pipeline_type", + default=None, + type=str, + help=( + "The pipeline type. One of 'FrozenOpenCLIPEmbedder', 'FrozenCLIPEmbedder', 'PaintByExample'" + ". If `None` pipeline will be automatically inferred." + ), + ) + parser.add_argument( + "--image_size", + default=None, + type=int, + help=( + "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2" + " Base. Use 768 for Stable Diffusion v2." + ), + ) + parser.add_argument( + "--prediction_type", + default=None, + type=str, + help=( + "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable" + " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2." + ), + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument( + "--upcast_attention", + action="store_true", + help=( + "Whether the attention computation should always be upcasted. This is necessary when running stable" + " diffusion 2.1." + ), + ) + parser.add_argument( + "--from_safetensors", + action="store_true", + help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", + ) + parser.add_argument( + "--to_safetensors", + action="store_true", + help="Whether to store pipeline in safetensors format or not.", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + parser.add_argument( + "--stable_unclip", + type=str, + default=None, + required=False, + help="Set if this is a stable unCLIP model. One of 'txt2img' or 'img2img'.", + ) + parser.add_argument( + "--stable_unclip_prior", + type=str, + default=None, + required=False, + help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. If `--stable_unclip` is set to `txt2img`, the karlo prior (https://huggingface.co/kakaobrain/karlo-v1-alpha/tree/main/prior) is selected by default.", + ) + parser.add_argument( + "--clip_stats_path", + type=str, + help="Path to the clip stats file. Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.", + required=False, + ) + parser.add_argument( + "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint." + ) + args = parser.parse_args() + + unet = UNet3DConditionModel() + import ipdb; ipdb.set_trace() + + # checkpoint = torch.load(args.checkpoint_path, map_location="cpu") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f480b4100907..05aa5350e5cf 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -40,6 +40,7 @@ Transformer2DModel, UNet1DModel, UNet2DConditionModel, + UNet3DConditionModel, UNet2DModel, VQModel, ) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e0b2cddd4bf9..752aeb409f57 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -25,6 +25,7 @@ from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel + from .unet_3d_condition import UNet3DConditionModel from .vq_model import VQModel if is_flax_available(): diff --git a/src/diffusers/models/transformer_temp.py b/src/diffusers/models/transformer_temp.py new file mode 100644 index 000000000000..2a8843c75ae5 --- /dev/null +++ b/src/diffusers/models/transformer_temp.py @@ -0,0 +1,313 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..models.embeddings import ImagePositionalEmbeddings +from ..utils import BaseOutput, deprecate +from .attention import BasicTransformerBlock +from .embeddings import PatchEmbed +from .modeling_utils import ModelMixin + + +@dataclass +class TransformerTempModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`TransformerTempModel`] is discrete): + Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions + for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class TransformerTempModel(ModelMixin, ConfigMixin): + """ + Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual + embeddings) inputs. + + When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard + transformer action. Finally, reshape to image. + + When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional + embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict + classes of unnoised image. + + Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised + image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. TransformerTempModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "TransformerTempModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "TransformerTempModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "TransformerTempModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continous projections + if use_linear_projection: + self.proj_out = nn.Linear(inner_dim, in_channels) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.transformer_2d.TransformerTempModelOutput`] or `tuple`: + [`~models.transformer_2d.TransformerTempModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + # TODO: cleanup! + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return TransformerTempModelOutput(sample=output) diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py new file mode 100644 index 000000000000..8db52c6bae3d --- /dev/null +++ b/src/diffusers/models/unet_3d_blocks.py @@ -0,0 +1,730 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + +from .resnet import Downsample2D, ResnetBlock2D, Upsample2D +from .transformer_2d import Transformer2DModel +from .transformer_temp import TransformerTempModel + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + TemporalConvBlock_v2( + in_channels, + in_channels, + dropout=0.1, + ) + ] + attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + attentions.append( + TransformerTempModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): + for attn, resnet, temp_conv in zip(self.attentions, self.resnets, self.temp_convs): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + temp_convs = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvBlock_v2( + out_channels, + out_channels, + dropout=0.1, + ) + ) + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + attentions.append( + TransformerTempModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for resnet, temp_conv, attn in zip(self.resnets, self.temp_convs, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvBlock_v2( + out_channels, + out_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + hidden_states = temp_conv(hidden_states) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvBlock_v2( + out_channels, + out_channels, + dropout=0.1, + ) + ) + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + attentions.append( + TransformerTempModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + cross_attention_kwargs=None, + upsample_size=None, + attention_mask=None, + ): + # TODO(Patrick, William) - attention mask is not used + for resnet, temp_conv, attn in zip(self.resnets, self.temp_convs, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvBlock_v2( + out_channels, + out_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + hidden_states = temp_conv(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class TemporalConvBlock_v2(nn.Module): + def __init__(self, + in_dim, + out_dim=None, + dropout=0.0, + use_image_dataset=False): + super(TemporalConvBlock_v2, self).__init__() + if out_dim is None: + out_dim = in_dim # int(1.5*in_dim) + self.in_dim = in_dim + self.out_dim = out_dim + self.use_image_dataset = use_image_dataset + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + self.conv3 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + self.conv4 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + + if self.use_image_dataset: + x = identity + 0.0 * x + else: + x = identity + x + return x diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 535741adeaeb..839fa55d3e60 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -22,16 +22,8 @@ from ..utils import BaseOutput, logging from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin -from .unet_2d_blocks import ( - CrossAttnDownBlock2D, - CrossAttnUpBlock2D, - DownBlock2D, - UNetMidBlock2DCrossAttn, - UNetMidBlock2DSimpleCrossAttn, - UpBlock2D, - get_down_block, - get_up_block, -) +from .transformer_temp import TransformerTempModel +from .unet_3d_blocks import get_down_block, get_up_block, UNetMidBlock3DCrossAttn, UpBlock3D, DownBlock3D, CrossAttnUpBlock3D, CrossAttnDownBlock3D logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -114,39 +106,36 @@ def __init__( sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, - center_input_sample: bool = False, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, + center_input_sample: bool = False, # remove + flip_sin_to_cos: bool = True, # remove + freq_shift: int = 0, # remove down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", ), - mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", - up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - only_cross_attention: Union[bool, Tuple[bool]] = False, + mid_block_type: Optional[str] = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, # remove block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: int = 2, + layers_per_block: int = 1, downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, + mid_block_scale_factor: float = 1, # remove + act_fn: str = "silu", # remove + norm_num_groups: Optional[int] = 32, # remove + norm_eps: float = 1e-5, # remove cross_attention_dim: int = 1280, - attention_head_dim: Union[int, Tuple[int]] = 8, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - num_class_embeds: Optional[int] = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - time_embedding_type: str = "positional", - timestep_post_act: Optional[str] = None, - time_cond_proj_dim: Optional[int] = None, - conv_in_kernel: int = 3, - conv_out_kernel: int = 3, - projection_class_embeddings_input_dim: Optional[int] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, # remove + use_linear_projection: bool = False, # remove + class_embed_type: Optional[str] = None, # remove + num_class_embeds: Optional[int] = None, # remove + upcast_attention: bool = False, # remvoe + resnet_time_scale_shift: str = "default", # remove + time_embedding_type: str = "positional", # remove + timestep_post_act: Optional[str] = None, # remove + time_cond_proj_dim: Optional[int] = None, # remove + projection_class_embeddings_input_dim: Optional[int] = None, # remove ): super().__init__() @@ -174,6 +163,8 @@ def __init__( ) # input + conv_in_kernel = 3 + conv_out_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding @@ -206,29 +197,17 @@ def __init__( cond_proj_dim=time_cond_proj_dim, ) - # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - elif class_embed_type == "projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" - ) - # The projection `class_embed_type` is the same as the timestep `class_embed_type` except - # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings - # 2. it projects from an arbitrary input dimension. - # - # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. - # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. - # As a result, `TimestepEmbedding` can be passed arbitrary vectors. - self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - else: - self.class_embedding = None + self.transformer_in = TransformerTempModel( + num_attention_heads=8, + attention_head_dim=64, + in_channels=block_out_channels[0], + num_layers=1, + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + # class embedding self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) @@ -258,7 +237,7 @@ def __init__( cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim[i], downsample_padding=downsample_padding, - dual_cross_attention=dual_cross_attention, + dual_cross_attention=False, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, @@ -267,37 +246,20 @@ def __init__( self.down_blocks.append(down_block) # mid - if mid_block_type == "UNetMidBlock2DCrossAttn": - self.mid_block = UNetMidBlock2DCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": - self.mid_block = UNetMidBlock2DSimpleCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif mid_block_type is None: - self.mid_block = None - else: - raise ValueError(f"unknown mid_block_type : {mid_block_type}") + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) # count how many layers upsample the images self.num_upsamplers = 0 @@ -335,7 +297,7 @@ def __init__( resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=reversed_attention_head_dim[i], - dual_cross_attention=dual_cross_attention, + dual_cross_attention=False, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, @@ -425,7 +387,7 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): module.gradient_checkpointing = value def forward( @@ -478,8 +440,8 @@ def forward( attention_mask = attention_mask.unsqueeze(1) # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 +# if self.config.center_input_sample: +# sample = 2 * sample - 1.0 # 1. time timesteps = timestep From 5973584d7f76354c588c307d877bccb4aa38da5a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 20 Mar 2023 11:11:01 +0000 Subject: [PATCH 04/43] match unet3d params --- src/diffusers/models/attention.py | 7 +-- src/diffusers/models/transformer_temp.py | 55 +++-------------------- src/diffusers/models/unet_3d_blocks.py | 30 +++++++++++-- src/diffusers/models/unet_3d_condition.py | 5 +-- 4 files changed, 39 insertions(+), 58 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index aa10bdd0e952..06be60e05219 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -202,6 +202,7 @@ def __init__( num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, + double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = True, norm_type: str = "layer_norm", @@ -233,10 +234,10 @@ def __init__( self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) # 2. Cross-Attn - if cross_attention_dim is not None: + if cross_attention_dim is not None or double_self_attention: self.attn2 = Attention( query_dim=dim, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, @@ -253,7 +254,7 @@ def __init__( else: self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - if cross_attention_dim is not None: + if cross_attention_dim is not None or double_self_attention: # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during # the second cross attention block. diff --git a/src/diffusers/models/transformer_temp.py b/src/diffusers/models/transformer_temp.py index 2a8843c75ae5..0c7e3dece0c0 100644 --- a/src/diffusers/models/transformer_temp.py +++ b/src/diffusers/models/transformer_temp.py @@ -19,10 +19,8 @@ from torch import nn from ..configuration_utils import ConfigMixin, register_to_config -from ..models.embeddings import ImagePositionalEmbeddings from ..utils import BaseOutput, deprecate from .attention import BasicTransformerBlock -from .embeddings import PatchEmbed from .modeling_utils import ModelMixin @@ -98,6 +96,7 @@ def __init__( upcast_attention: bool = False, norm_type: str = "layer_norm", norm_elementwise_affine: bool = True, + double_self_attention: bool = True, ): super().__init__() self.use_linear_projection = use_linear_projection @@ -139,40 +138,10 @@ def __init__( ) # 2. Define input layers - if self.is_input_continuous: - self.in_channels = in_channels - - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - if use_linear_projection: - self.proj_in = nn.Linear(in_channels, inner_dim) - else: - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - elif self.is_input_vectorized: - assert sample_size is not None, "TransformerTempModel over discrete input must provide sample_size" - assert num_vector_embeds is not None, "TransformerTempModel over discrete input must provide num_embed" - - self.height = sample_size - self.width = sample_size - self.num_vector_embeds = num_vector_embeds - self.num_latent_pixels = self.height * self.width - - self.latent_image_embedding = ImagePositionalEmbeddings( - num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width - ) - elif self.is_input_patches: - assert sample_size is not None, "TransformerTempModel over patched input must provide sample_size" - - self.height = sample_size - self.width = sample_size + self.in_channels = in_channels - self.patch_size = patch_size - self.pos_embed = PatchEmbed( - height=sample_size, - width=sample_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=inner_dim, - ) + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) # 3. Define transformers blocks self.transformer_blocks = nn.ModuleList( @@ -187,6 +156,7 @@ def __init__( num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, upcast_attention=upcast_attention, norm_type=norm_type, norm_elementwise_affine=norm_elementwise_affine, @@ -196,20 +166,7 @@ def __init__( ) # 4. Define output layers - self.out_channels = in_channels if out_channels is None else out_channels - if self.is_input_continuous: - # TODO: should use out_channels for continous projections - if use_linear_projection: - self.proj_out = nn.Linear(inner_dim, in_channels) - else: - self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - elif self.is_input_vectorized: - self.norm_out = nn.LayerNorm(inner_dim) - self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) - elif self.is_input_patches: - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) - self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + self.proj_out = nn.Linear(inner_dim, in_channels) def forward( self, diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 8db52c6bae3d..dcfa35957565 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -206,6 +206,28 @@ def __init__( upcast_attention=upcast_attention, ) ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvBlock_v2( + in_channels, + in_channels, + dropout=0.1, + ) + ) + self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) self.attentions = nn.ModuleList(attentions) @@ -213,7 +235,9 @@ def __init__( def forward( self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None ): - for attn, resnet, temp_conv in zip(self.attentions, self.resnets, self.temp_convs): + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.temp_convs[0](hidden_states) + for attn, resnet, temp_conv in zip(self.attentions, self.resnets[1:], self.temp_convs[1:]): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -295,8 +319,8 @@ def __init__( attentions.append( TransformerTempModel( attn_num_head_channels, - in_channels // attn_num_head_channels, - in_channels=in_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 839fa55d3e60..ac11336f5ce6 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -119,13 +119,13 @@ def __init__( up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), only_cross_attention: Union[bool, Tuple[bool]] = False, # remove block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: int = 1, + layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, # remove act_fn: str = "silu", # remove norm_num_groups: Optional[int] = 32, # remove norm_eps: float = 1e-5, # remove - cross_attention_dim: int = 1280, + cross_attention_dim: int = 1024, attention_head_dim: Union[int, Tuple[int]] = 8, # remove use_linear_projection: bool = False, # remove class_embed_type: Optional[str] = None, # remove @@ -202,7 +202,6 @@ def __init__( attention_head_dim=64, in_channels=block_out_channels[0], num_layers=1, - cross_attention_dim=cross_attention_dim, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, ) From d91862d2d4d11024f7417a45f27fc00ce164e32e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 20 Mar 2023 15:56:26 +0000 Subject: [PATCH 05/43] make sure weights are correcctly converted --- .../convert_ms_text_to_video_to_diffusers.py | 387 +++++++++++++++++- src/diffusers/models/unet_3d_blocks.py | 33 +- src/diffusers/models/unet_3d_condition.py | 2 +- .../stable_diffusion/convert_from_ckpt.py | 3 +- 4 files changed, 415 insertions(+), 10 deletions(-) diff --git a/scripts/convert_ms_text_to_video_to_diffusers.py b/scripts/convert_ms_text_to_video_to_diffusers.py index 59c7522bf838..cb97e8e103e9 100644 --- a/scripts/convert_ms_text_to_video_to_diffusers.py +++ b/scripts/convert_ms_text_to_video_to_diffusers.py @@ -19,6 +19,379 @@ from diffusers import UNet3DConditionModel +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + weight = old_checkpoint[path["old"]] + names = ["proj_attn.weight"] + names_2 = ["proj_out.weight", "proj_in.weight"] + if any(k in new_path for k in names): + checkpoint[new_path] = weight[:, :, 0] + elif any(k in new_path for k in names_2) and len(weight.shape) > 2 and ".attentions." not in new_path: + checkpoint[new_path] = weight[:, :, 0] + else: + checkpoint[new_path] = weight + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_temp_conv_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + mapping.append({"old": old_item, "new": old_item}) + + return mapping + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + if "temopral_conv" not in old_item: + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + first_temp_attention = [v for v in unet_state_dict if v.startswith("input_blocks.0.1")] + paths = renew_attention_paths(first_temp_attention) + meta_path = {"old": "input_blocks.0.1", "new": "transformer_in"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + temp_attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.2" in key] + + if f"input_blocks.{i}.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + temporal_convs = [ + key for key in resnets if "temopral_conv" in key + ] + paths = renew_temp_conv_paths(temporal_convs) + meta_path = {"old": f"input_blocks.{i}.0.temopral_conv", "new": f"down_blocks.{block_id}.temp_convs.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(temp_attentions): + paths = renew_attention_paths(temp_attentions) + meta_path = {"old": f"input_blocks.{i}.2", "new": f"down_blocks.{block_id}.temp_attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + temporal_convs_0 = [ + key for key in resnet_0 if "temopral_conv" in key + ] + attentions = middle_blocks[1] + temp_attentions = middle_blocks[2] + resnet_1 = middle_blocks[3] + temporal_convs_1 = [ + key for key in resnet_1 if "temopral_conv" in key + ] + + resnet_0_paths = renew_resnet_paths(resnet_0) + meta_path = {"old": "middle_block.0", "new": "mid_block.resnets.0"} + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]) + + temp_conv_0_paths = renew_temp_conv_paths(temporal_convs_0) + meta_path = {"old": "middle_block.0.temopral_conv", "new": "mid_block.temp_convs.0"} + assign_to_checkpoint(temp_conv_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]) + + resnet_1_paths = renew_resnet_paths(resnet_1) + meta_path = {"old": "middle_block.3", "new": "mid_block.resnets.1"} + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]) + + temp_conv_1_paths = renew_temp_conv_paths(temporal_convs_1) + meta_path = {"old": "middle_block.3.temopral_conv", "new": "mid_block.temp_convs.1"} + assign_to_checkpoint(temp_conv_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + temp_attentions_paths = renew_attention_paths(temp_attentions) + meta_path = {"old": "middle_block.2", "new": "mid_block.temp_attentions.0"} + assign_to_checkpoint( + temp_attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + temp_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + temporal_convs = [ + key for key in resnets if "temopral_conv" in key + ] + paths = renew_temp_conv_paths(temporal_convs) + meta_path = {"old": f"output_blocks.{i}.0.temopral_conv", "new": f"up_blocks.{block_id}.temp_convs.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(temp_attentions): + paths = renew_attention_paths(temp_attentions) + meta_path = { + "old": f"output_blocks.{i}.2", + "new": f"up_blocks.{block_id}.temp_attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + new_checkpoint[new_path] = unet_state_dict[old_path] + + temopral_conv_paths = [l for l in output_block_layers if "temopral_conv" in l] + for path in temopral_conv_paths: + pruned_path = path.split("temopral_conv.")[-1] + old_path = ".".join(["output_blocks", str(i), str(block_id) , "temopral_conv", pruned_path]) + new_path = ".".join(["up_blocks", str(block_id), "temp_convs", str(layer_in_block_id), pruned_path]) + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -125,7 +498,17 @@ ) args = parser.parse_args() + unet_checkpoint = torch.load(args.checkpoint_path, map_location="cpu") unet = UNet3DConditionModel() - import ipdb; ipdb.set_trace() - # checkpoint = torch.load(args.checkpoint_path, map_location="cpu") + converted_ckpt = convert_ldm_unet_checkpoint(unet_checkpoint, unet.config) + + diff_0 = set(unet.state_dict().keys()) - set(converted_ckpt.keys()) + diff_1 = set(converted_ckpt.keys()) - set(unet.state_dict().keys()) + + assert len(diff_0) == len(diff_1) == 0, "Converted weights don't match" + + # load state_dict + unet.load_state_dict(converted_ckpt) + + # -- finish converting the unet -- diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index dcfa35957565..9c11e03fd70f 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -180,6 +180,7 @@ def __init__( ) ] attentions = [] + temp_attentions = [] for _ in range(num_layers): attentions.append( @@ -194,7 +195,7 @@ def __init__( upcast_attention=upcast_attention, ) ) - attentions.append( + temp_attentions.append( TransformerTempModel( attn_num_head_channels, in_channels // attn_num_head_channels, @@ -231,18 +232,24 @@ def __init__( self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) def forward( self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None ): hidden_states = self.resnets[0](hidden_states, temb) hidden_states = self.temp_convs[0](hidden_states) - for attn, resnet, temp_conv in zip(self.attentions, self.resnets[1:], self.temp_convs[1:]): + for attn, temp_attn, resnet, temp_conv in zip(self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, ).sample + hidden_states = temp_attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states) @@ -275,6 +282,7 @@ def __init__( super().__init__() resnets = [] attentions = [] + temp_attentions = [] temp_convs = [] self.has_cross_attention = True @@ -316,7 +324,7 @@ def __init__( upcast_attention=upcast_attention, ) ) - attentions.append( + temp_attentions.append( TransformerTempModel( attn_num_head_channels, out_channels // attn_num_head_channels, @@ -331,6 +339,7 @@ def __init__( self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) if add_downsample: self.downsamplers = nn.ModuleList( @@ -351,7 +360,7 @@ def forward( # TODO(Patrick, William) - attention mask is not used output_states = () - for resnet, temp_conv, attn in zip(self.resnets, self.temp_convs, self.attentions): + for resnet, temp_conv, attn, temp_attn in zip(self.resnets, self.temp_convs, self.attentions, self.temp_attentions): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -378,6 +387,11 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, ).sample + hidden_states = temp_attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample output_states += (hidden_states,) @@ -507,6 +521,7 @@ def __init__( resnets = [] temp_convs = [] attentions = [] + temp_attentions = [] self.has_cross_attention = True self.attn_num_head_channels = attn_num_head_channels @@ -549,7 +564,7 @@ def __init__( upcast_attention=upcast_attention, ) ) - attentions.append( + temp_attentions.append( TransformerTempModel( attn_num_head_channels, out_channels // attn_num_head_channels, @@ -565,6 +580,7 @@ def __init__( self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) @@ -584,7 +600,7 @@ def forward( attention_mask=None, ): # TODO(Patrick, William) - attention mask is not used - for resnet, temp_conv, attn in zip(self.resnets, self.temp_convs, self.attentions): + for resnet, temp_conv, attn, temp_attn in zip(self.resnets, self.temp_convs, self.attentions, self.temp_attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] @@ -616,6 +632,11 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, ).sample + hidden_states = temp_attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index ac11336f5ce6..9010b49019a9 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -127,7 +127,7 @@ def __init__( norm_eps: float = 1e-5, # remove cross_attention_dim: int = 1024, attention_head_dim: Union[int, Tuple[int]] = 8, # remove - use_linear_projection: bool = False, # remove + use_linear_projection: bool = True, # remove class_embed_type: Optional[str] = None, # remove num_class_embeds: Optional[int] = None, # remove upcast_attention: bool = False, # remvoe diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index ef4598433f82..f26e466ead34 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -200,7 +200,8 @@ def assign_to_checkpoint( # Global renaming happens here new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + new_path = new_path.replace("middle_block.2", "mid_block.temp_attentions.0") + new_path = new_path.replace("middle_block.3", "mid_block.resnets.1") if additional_replacements is not None: for replacement in additional_replacements: From aeab5adc12a82ed428aa0e9c30d4a21731ebc789 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 21 Mar 2023 08:13:01 +0000 Subject: [PATCH 06/43] improve --- scripts/convert_ms_text_to_video_to_diffusers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/convert_ms_text_to_video_to_diffusers.py b/scripts/convert_ms_text_to_video_to_diffusers.py index cb97e8e103e9..6948e2097df5 100644 --- a/scripts/convert_ms_text_to_video_to_diffusers.py +++ b/scripts/convert_ms_text_to_video_to_diffusers.py @@ -511,4 +511,6 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False # load state_dict unet.load_state_dict(converted_ckpt) + unet.save_pretrained(args.dump_path) + # -- finish converting the unet -- From d9dd98ccd28f62b681c5574b4d0d61e8e6389a09 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 21 Mar 2023 16:31:20 +0000 Subject: [PATCH 07/43] forward pass works, but diff result --- src/diffusers/models/transformer_temp.py | 66 +++-------- src/diffusers/models/unet_3d_blocks.py | 135 ++++++---------------- src/diffusers/models/unet_3d_condition.py | 29 ++--- 3 files changed, 62 insertions(+), 168 deletions(-) diff --git a/src/diffusers/models/transformer_temp.py b/src/diffusers/models/transformer_temp.py index 0c7e3dece0c0..6735a324ffd2 100644 --- a/src/diffusers/models/transformer_temp.py +++ b/src/diffusers/models/transformer_temp.py @@ -174,6 +174,7 @@ def forward( encoder_hidden_states=None, timestep=None, class_labels=None, + num_frames=1, cross_attention_kwargs=None, return_dict: bool = True, ): @@ -199,23 +200,18 @@ def forward( returning a tuple, the first element is the sample tensor. """ # 1. Input - if self.is_input_continuous: - batch, _, height, width = hidden_states.shape - residual = hidden_states + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames - hidden_states = self.norm(hidden_states) - if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - else: - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = self.proj_in(hidden_states) - elif self.is_input_vectorized: - hidden_states = self.latent_image_embedding(hidden_states) - elif self.is_input_patches: - hidden_states = self.pos_embed(hidden_states) + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) # 2. Blocks for block in self.transformer_blocks: @@ -228,41 +224,11 @@ def forward( ) # 3. Output - if self.is_input_continuous: - if not self.use_linear_projection: - hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = self.proj_out(hidden_states) - else: - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - - output = hidden_states + residual - elif self.is_input_vectorized: - hidden_states = self.norm_out(hidden_states) - logits = self.out(hidden_states) - # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) - logits = logits.permute(0, 2, 1) + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states[None, None, :].reshape(batch_size, height, width, channel, num_frames).permute(0, 3, 4, 1, 2).contiguous() + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) - # log(p(x_0)) - output = F.log_softmax(logits.double(), dim=1).float() - elif self.is_input_patches: - # TODO: cleanup! - conditioning = self.transformer_blocks[0].norm1.emb( - timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] - hidden_states = self.proj_out_2(hidden_states) - - # unpatchify - height = width = int(hidden_states.shape[1] ** 0.5) - hidden_states = hidden_states.reshape( - shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) - ) - hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) - ) + output = hidden_states + residual if not return_dict: return (output,) diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 9c11e03fd70f..f81d6c96f59a 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -235,23 +235,19 @@ def __init__( self.temp_attentions = nn.ModuleList(temp_attentions) def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, num_frames=1, cross_attention_kwargs=None ): hidden_states = self.resnets[0](hidden_states, temb) - hidden_states = self.temp_convs[0](hidden_states) + hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) for attn, temp_attn, resnet, temp_conv in zip(self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, ).sample - hidden_states = temp_attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample hidden_states = resnet(hidden_states, temb) - hidden_states = temp_conv(hidden_states) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) return hidden_states @@ -355,43 +351,20 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, num_frames=1, cross_attention_kwargs=None ): # TODO(Patrick, William) - attention mask is not used output_states = () for resnet, temp_conv, attn, temp_attn in zip(self.resnets, self.temp_convs, self.attentions, self.temp_attentions): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = temp_conv(hidden_states) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - hidden_states = temp_attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample output_states += (hidden_states,) @@ -465,23 +438,12 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, num_frames=1): output_states = () for resnet, temp_conv in zip(self.resnets, self.temp_convs): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - else: - hidden_states = resnet(hidden_states, temb) - - hidden_states = temp_conv(hidden_states) + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) output_states += (hidden_states,) @@ -595,9 +557,10 @@ def forward( res_hidden_states_tuple, temb=None, encoder_hidden_states=None, - cross_attention_kwargs=None, upsample_size=None, attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, ): # TODO(Patrick, William) - attention mask is not used for resnet, temp_conv, attn, temp_attn in zip(self.resnets, self.temp_convs, self.attentions, self.temp_attentions): @@ -606,37 +569,14 @@ def forward( res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = temp_conv(hidden_states) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - hidden_states = temp_attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -702,26 +642,15 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1): for resnet, temp_conv in zip(self.resnets, self.temp_convs): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - else: - hidden_states = resnet(hidden_states, temb) - - hidden_states = temp_conv(hidden_states) + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -761,7 +690,9 @@ def __init__(self, nn.init.zeros_(self.conv4[-1].weight) nn.init.zeros_(self.conv4[-1].bias) - def forward(self, x): + def forward(self, x, num_frames=1): + x = x[None, :].reshape((-1, num_frames) + x.shape[1:]).permute(0, 2, 1, 3, 4) + identity = x x = self.conv1(x) x = self.conv2(x) @@ -772,4 +703,6 @@ def forward(self, x): x = identity + 0.0 * x else: x = identity + x + + x = x.permute(0, 2, 1, 3, 4).reshape((x.shape[0] * x.shape[2], -1) + x.shape[3:]) return x diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 9010b49019a9..62721c22505e 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -438,10 +438,6 @@ def forward( attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) - # 0. center input if necessary -# if self.config.center_input_sample: -# sample = 2 * sample - 1.0 - # 1. time timesteps = timestep if not torch.is_tensor(timesteps): @@ -457,6 +453,7 @@ def forward( timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) @@ -467,20 +464,15 @@ def forward( t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb, timestep_cond) - - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) sample = self.conv_in(sample) + sample = self.transformer_in(sample, num_frames=num_frames).sample + # 3. down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: @@ -490,10 +482,11 @@ def forward( temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, ) else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) down_block_res_samples += res_samples @@ -515,6 +508,7 @@ def forward( emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, ) @@ -539,13 +533,14 @@ def forward( temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, ) else: sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, num_frames=num_frames ) # 6. post-process From 40c80e2d7d3f308859e10e57524ba6dcf991e617 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 21 Mar 2023 19:05:46 +0000 Subject: [PATCH 08/43] make forward work --- src/diffusers/models/unet_3d_blocks.py | 8 ++++---- src/diffusers/models/unet_3d_condition.py | 20 +++++++++++++++----- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index f81d6c96f59a..a3349d071b24 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -309,8 +309,8 @@ def __init__( ) attentions.append( Transformer2DModel( - attn_num_head_channels, out_channels // attn_num_head_channels, + attn_num_head_channels, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, @@ -322,8 +322,8 @@ def __init__( ) temp_attentions.append( TransformerTempModel( - attn_num_head_channels, out_channels // attn_num_head_channels, + attn_num_head_channels, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, @@ -515,8 +515,8 @@ def __init__( ) attentions.append( Transformer2DModel( - attn_num_head_channels, out_channels // attn_num_head_channels, + attn_num_head_channels, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, @@ -528,8 +528,8 @@ def __init__( ) temp_attentions.append( TransformerTempModel( - attn_num_head_channels, out_channels // attn_num_head_channels, + attn_num_head_channels, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 62721c22505e..c6b1bb825ace 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -121,12 +121,12 @@ def __init__( block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, - mid_block_scale_factor: float = 1, # remove - act_fn: str = "silu", # remove - norm_num_groups: Optional[int] = 32, # remove + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, # remove cross_attention_dim: int = 1024, - attention_head_dim: Union[int, Tuple[int]] = 8, # remove + attention_head_dim: Union[int, Tuple[int]] = 64, use_linear_projection: bool = True, # remove class_embed_type: Optional[str] = None, # remove num_class_embeds: Optional[int] = None, # remove @@ -199,7 +199,7 @@ def __init__( self.transformer_in = TransformerTempModel( num_attention_heads=8, - attention_head_dim=64, + attention_head_dim=attention_head_dim, in_channels=block_out_channels[0], num_layers=1, use_linear_projection=use_linear_projection, @@ -474,6 +474,7 @@ def forward( sample = self.transformer_in(sample, num_frames=num_frames).sample # 3. down + print("0", sample.abs().sum()) down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: @@ -501,6 +502,8 @@ def forward( down_block_res_samples = new_down_block_res_samples + print("1", sample.abs().sum()) + # 4. mid if self.mid_block is not None: sample = self.mid_block( @@ -515,6 +518,8 @@ def forward( if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual + print("2", sample.abs().sum()) + # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 @@ -543,12 +548,17 @@ def forward( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, num_frames=num_frames ) + print("3", sample.abs().sum()) + # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) + sample = self.conv_out(sample) + print("4", sample.abs().sum()) + if not return_dict: return (sample,) From c4f0aebb90c7e855fd18e6199a74258475ee17ff Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 21 Mar 2023 20:12:25 +0000 Subject: [PATCH 09/43] fix more --- .../convert_ms_text_to_video_to_diffusers.py | 55 +++++++----- src/diffusers/__init__.py | 3 +- src/diffusers/models/transformer_temp.py | 8 +- src/diffusers/models/unet_3d_blocks.py | 59 ++++++++---- src/diffusers/models/unet_3d_condition.py | 19 +++- src/diffusers/pipelines/__init__.py | 1 + .../text_to_video_synthesis/__init__.py | 29 ++++++ .../pipeline_text_to_video_synth.py | 90 +++++++++++-------- 8 files changed, 181 insertions(+), 83 deletions(-) create mode 100644 src/diffusers/pipelines/text_to_video_synthesis/__init__.py diff --git a/scripts/convert_ms_text_to_video_to_diffusers.py b/scripts/convert_ms_text_to_video_to_diffusers.py index 6948e2097df5..91e699932558 100644 --- a/scripts/convert_ms_text_to_video_to_diffusers.py +++ b/scripts/convert_ms_text_to_video_to_diffusers.py @@ -15,7 +15,9 @@ """ Conversion script for the LDM checkpoints. """ import argparse + import torch + from diffusers import UNet3DConditionModel @@ -191,9 +193,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False first_temp_attention = [v for v in unet_state_dict if v.startswith("input_blocks.0.1")] paths = renew_attention_paths(first_temp_attention) meta_path = {"old": "input_blocks.0.1", "new": "transformer_in"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] @@ -245,11 +245,12 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) - temporal_convs = [ - key for key in resnets if "temopral_conv" in key - ] + temporal_convs = [key for key in resnets if "temopral_conv" in key] paths = renew_temp_conv_paths(temporal_convs) - meta_path = {"old": f"input_blocks.{i}.0.temopral_conv", "new": f"down_blocks.{block_id}.temp_convs.{layer_in_block_id}"} + meta_path = { + "old": f"input_blocks.{i}.0.temopral_conv", + "new": f"down_blocks.{block_id}.temp_convs.{layer_in_block_id}", + } assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) @@ -263,37 +264,44 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False if len(temp_attentions): paths = renew_attention_paths(temp_attentions) - meta_path = {"old": f"input_blocks.{i}.2", "new": f"down_blocks.{block_id}.temp_attentions.{layer_in_block_id}"} + meta_path = { + "old": f"input_blocks.{i}.2", + "new": f"down_blocks.{block_id}.temp_attentions.{layer_in_block_id}", + } assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) resnet_0 = middle_blocks[0] - temporal_convs_0 = [ - key for key in resnet_0 if "temopral_conv" in key - ] + temporal_convs_0 = [key for key in resnet_0 if "temopral_conv" in key] attentions = middle_blocks[1] temp_attentions = middle_blocks[2] resnet_1 = middle_blocks[3] - temporal_convs_1 = [ - key for key in resnet_1 if "temopral_conv" in key - ] + temporal_convs_1 = [key for key in resnet_1 if "temopral_conv" in key] resnet_0_paths = renew_resnet_paths(resnet_0) meta_path = {"old": "middle_block.0", "new": "mid_block.resnets.0"} - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]) + assign_to_checkpoint( + resnet_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path] + ) temp_conv_0_paths = renew_temp_conv_paths(temporal_convs_0) meta_path = {"old": "middle_block.0.temopral_conv", "new": "mid_block.temp_convs.0"} - assign_to_checkpoint(temp_conv_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]) + assign_to_checkpoint( + temp_conv_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path] + ) resnet_1_paths = renew_resnet_paths(resnet_1) meta_path = {"old": "middle_block.3", "new": "mid_block.resnets.1"} - assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]) + assign_to_checkpoint( + resnet_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path] + ) temp_conv_1_paths = renew_temp_conv_paths(temporal_convs_1) meta_path = {"old": "middle_block.3.temopral_conv", "new": "mid_block.temp_convs.1"} - assign_to_checkpoint(temp_conv_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]) + assign_to_checkpoint( + temp_conv_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path] + ) attentions_paths = renew_attention_paths(attentions) meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} @@ -333,11 +341,12 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) - temporal_convs = [ - key for key in resnets if "temopral_conv" in key - ] + temporal_convs = [key for key in resnets if "temopral_conv" in key] paths = renew_temp_conv_paths(temporal_convs) - meta_path = {"old": f"output_blocks.{i}.0.temopral_conv", "new": f"up_blocks.{block_id}.temp_convs.{layer_in_block_id}"} + meta_path = { + "old": f"output_blocks.{i}.0.temopral_conv", + "new": f"up_blocks.{block_id}.temp_convs.{layer_in_block_id}", + } assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) @@ -385,7 +394,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False temopral_conv_paths = [l for l in output_block_layers if "temopral_conv" in l] for path in temopral_conv_paths: pruned_path = path.split("temopral_conv.")[-1] - old_path = ".".join(["output_blocks", str(i), str(block_id) , "temopral_conv", pruned_path]) + old_path = ".".join(["output_blocks", str(i), str(block_id), "temopral_conv", pruned_path]) new_path = ".".join(["up_blocks", str(block_id), "temp_convs", str(layer_in_block_id), pruned_path]) new_checkpoint[new_path] = unet_state_dict[old_path] diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 05aa5350e5cf..eebdbb425718 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -40,8 +40,8 @@ Transformer2DModel, UNet1DModel, UNet2DConditionModel, - UNet3DConditionModel, UNet2DModel, + UNet3DConditionModel, VQModel, ) from .optimization import ( @@ -131,6 +131,7 @@ StableDiffusionUpscalePipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, + TextToVideoMSPipeline, UnCLIPImageVariationPipeline, UnCLIPPipeline, VersatileDiffusionDualGuidedPipeline, diff --git a/src/diffusers/models/transformer_temp.py b/src/diffusers/models/transformer_temp.py index 6735a324ffd2..ab813a178723 100644 --- a/src/diffusers/models/transformer_temp.py +++ b/src/diffusers/models/transformer_temp.py @@ -15,7 +15,6 @@ from typing import Optional import torch -import torch.nn.functional as F from torch import nn from ..configuration_utils import ConfigMixin, register_to_config @@ -225,7 +224,12 @@ def forward( # 3. Output hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states[None, None, :].reshape(batch_size, height, width, channel, num_frames).permute(0, 3, 4, 1, 2).contiguous() + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, channel, num_frames) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) hidden_states = hidden_states.reshape(batch_frames, channel, height, width) output = hidden_states + residual diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index a3349d071b24..8e2b8d2f2e5f 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -235,11 +235,19 @@ def __init__( self.temp_attentions = nn.ModuleList(temp_attentions) def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, num_frames=1, cross_attention_kwargs=None + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, ): hidden_states = self.resnets[0](hidden_states, temb) hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) - for attn, temp_attn, resnet, temp_conv in zip(self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]): + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -351,12 +359,20 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, num_frames=1, cross_attention_kwargs=None + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, ): # TODO(Patrick, William) - attention mask is not used output_states = () - for resnet, temp_conv, attn, temp_attn in zip(self.resnets, self.temp_convs, self.attentions, self.temp_attentions): + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) hidden_states = attn( @@ -563,7 +579,9 @@ def forward( cross_attention_kwargs=None, ): # TODO(Patrick, William) - attention mask is not used - for resnet, temp_conv, attn, temp_attn in zip(self.resnets, self.temp_convs, self.attentions, self.temp_attentions): + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] @@ -660,11 +678,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si class TemporalConvBlock_v2(nn.Module): - def __init__(self, - in_dim, - out_dim=None, - dropout=0.0, - use_image_dataset=False): + def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset=False): super(TemporalConvBlock_v2, self).__init__() if out_dim is None: out_dim = in_dim # int(1.5*in_dim) @@ -674,17 +688,26 @@ def __init__(self, # conv layers self.conv1 = nn.Sequential( - nn.GroupNorm(32, in_dim), nn.SiLU(), - nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))) + nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)) + ) self.conv2 = nn.Sequential( - nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), - nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) self.conv3 = nn.Sequential( - nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), - nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) self.conv4 = nn.Sequential( - nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), - nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) # zero out the last layer params,so the conv block is identity nn.init.zeros_(self.conv4[-1].weight) diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index c6b1bb825ace..633195e1ca89 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -23,7 +23,15 @@ from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .transformer_temp import TransformerTempModel -from .unet_3d_blocks import get_down_block, get_up_block, UNetMidBlock3DCrossAttn, UpBlock3D, DownBlock3D, CrossAttnUpBlock3D, CrossAttnDownBlock3D +from .unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -545,7 +553,11 @@ def forward( ) else: sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, num_frames=num_frames + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, ) print("3", sample.abs().sum()) @@ -559,6 +571,9 @@ def forward( print("4", sample.abs().sum()) + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + if not return_dict: return (sample,) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 5b6c729f80be..1190ad7e1cf3 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -65,6 +65,7 @@ StableUnCLIPPipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .text_to_video_synthesis import TextToVideoMSPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .versatile_diffusion import ( VersatileDiffusionDualGuidedPipeline, diff --git a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py new file mode 100644 index 000000000000..425489c86fdb --- /dev/null +++ b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py @@ -0,0 +1,29 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np + +from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available + + +@dataclass +class TextToVideoMSPipelineOutput(BaseOutput): + """ + Output class for text to video pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + image: np.ndarray + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + from .pipeline_text_to_video_synth import TextToVideoMSPipeline # noqa: F401 diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 6caab78040a7..30024538009e 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -18,7 +18,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, UNet3DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, @@ -39,7 +39,7 @@ >>> import torch >>> from diffusers import TextToVideoMSPipeline - >>> pipe = TextToVideoMSPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = TextToVideoMSPipeline.from_pretrained("diffusers/ms-text-to-video-1.7b", torch_dtype=torch.float16) >>> pipe = pipe.to("cuda") >>> prompt = "a photo of an astronaut riding a horse on mars" @@ -48,6 +48,18 @@ """ +def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): + mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) # ncfhw + std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) # ncfhw + video = video.mul_(std).add_(mean) # unnormalize back to [0,1] + video.clamp_(0, 1) + i, c, f, h, w = video.shape + images = video.permute(2, 3, 0, 4, 1).reshape(f, h, i * w, c) + images = images.unbind(dim=0) + images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c + return images + + class TextToVideoMSPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion. @@ -65,7 +77,7 @@ class TextToVideoMSPipeline(DiffusionPipeline): tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. @@ -76,7 +88,7 @@ def __init__( vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, + unet: UNet3DConditionModel, scheduler: KarrasDiffusionSchedulers, ): super().__init__() @@ -125,10 +137,9 @@ def disable_vae_tiling(self): def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. + text_encoder, vae have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded + to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a + submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload @@ -325,23 +336,28 @@ def _encode_prompt( return prompt_embeds - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) + video = ( + image[None, :] + .reshape( + ( + batch_size, + num_frames, + -1, + ) + + image.shape[2:] + ) + .permute(0, 2, 1, 3, 4) + ) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image + video = video.float() + return video def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -407,8 +423,16 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -431,6 +455,7 @@ def __call__( prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, + num_frames: int = 16, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -558,6 +583,7 @@ def __call__( latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, + num_frames, height, width, prompt_embeds.dtype, @@ -599,24 +625,14 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - if output_type == "latent": - image = latents - has_nsfw_concept = None - elif output_type == "pil": - # 8. Post-processing - image = self.decode_latents(latents) - - # 10. Convert to PIL - image = self.numpy_to_pil(image) - else: - # 8. Post-processing - image = self.decode_latents(latents) + video_tensor = self.decode_latents(latents) + video = tensor2vid(video_tensor) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: - return (image, has_nsfw_concept) + return (video,) - return TextToVideoMSPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return TextToVideoMSPipelineOutput(image=video) From faa4e6df842e88efee12fbe1e7e37d42554b9656 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 21 Mar 2023 20:50:54 +0000 Subject: [PATCH 10/43] finish --- src/diffusers/models/unet_3d_condition.py | 9 --------- .../pipeline_text_to_video_synth.py | 2 +- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 633195e1ca89..e1dce2a07db6 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -482,7 +482,6 @@ def forward( sample = self.transformer_in(sample, num_frames=num_frames).sample # 3. down - print("0", sample.abs().sum()) down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: @@ -510,8 +509,6 @@ def forward( down_block_res_samples = new_down_block_res_samples - print("1", sample.abs().sum()) - # 4. mid if self.mid_block is not None: sample = self.mid_block( @@ -526,8 +523,6 @@ def forward( if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual - print("2", sample.abs().sum()) - # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 @@ -560,8 +555,6 @@ def forward( num_frames=num_frames, ) - print("3", sample.abs().sum()) - # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) @@ -569,8 +562,6 @@ def forward( sample = self.conv_out(sample) - print("4", sample.abs().sum()) - # reshape to (batch, channel, framerate, width, height) sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 30024538009e..cdbe882cf819 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -457,7 +457,7 @@ def __call__( width: Optional[int] = None, num_frames: int = 16, num_inference_steps: int = 50, - guidance_scale: float = 7.5, + guidance_scale: float = 9.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, From e27769b62bed2de5240c374602be4209e5db5505 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 22 Mar 2023 09:59:33 +0530 Subject: [PATCH 11/43] refactor video output class. --- .../text_to_video_synthesis/__init__.py | 9 +++--- .../pipeline_text_to_video_synth.py | 28 +++++++++++-------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py index 425489c86fdb..c1a75a38927e 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py @@ -12,12 +12,13 @@ class TextToVideoMSPipelineOutput(BaseOutput): Output class for text to video pipelines. Args: - images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, - num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + frames (`List[np.ndarray]`) + List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)`. + NumPy array present the denoised images of the diffusion pipeline. The length of the list denotes the video + length i.e., the number of frames. """ - image: np.ndarray + frames: List[np.ndarray] try: diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index cdbe882cf819..a182239e0f2b 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -15,6 +15,7 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union +import numpy as np import torch from transformers import CLIPTextModel, CLIPTokenizer @@ -42,20 +43,25 @@ >>> pipe = TextToVideoMSPipeline.from_pretrained("diffusers/ms-text-to-video-1.7b", torch_dtype=torch.float16) >>> pipe = pipe.to("cuda") - >>> prompt = "a photo of an astronaut riding a horse on mars" - >>> image = pipe(prompt).images[0] + >>> prompt = "Spiderman is surfing" + >>> video_frames = pipe(prompt).frames ``` """ -def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): - mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) # ncfhw - std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) # ncfhw - video = video.mul_(std).add_(mean) # unnormalize back to [0,1] +def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: + # reshape to ncfhw + mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) + # unnormalize back to [0,1] + video = video.mul_(std).add_(mean) video.clamp_(0, 1) + # prepare the final outputs i, c, f, h, w = video.shape - images = video.permute(2, 3, 0, 4, 1).reshape(f, h, i * w, c) - images = images.unbind(dim=0) + images = video.permute(2, 3, 0, 4, 1).reshape( + f, h, i * w, c + ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) + images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c return images @@ -536,9 +542,7 @@ def __call__( Returns: [`~pipelines.stable_diffusion.TextToVideoMSPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.TextToVideoMSPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. + When returning a tuple, the first element is a list with the generated frames. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor @@ -635,4 +639,4 @@ def __call__( if not return_dict: return (video,) - return TextToVideoMSPipelineOutput(image=video) + return TextToVideoMSPipelineOutput(frames=video) From d5e544fdca8a72619eb98dffab0aa0c2e7caad41 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 22 Mar 2023 10:11:38 +0530 Subject: [PATCH 12/43] feat: add support for a video export utility. --- .../pipeline_text_to_video_synth.py | 3 +++ src/diffusers/utils/__init__.py | 3 +++ src/diffusers/utils/import_utils.py | 17 ++++++++++++ src/diffusers/utils/testing_utils.py | 26 +++++++++++++++++-- 4 files changed, 47 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index a182239e0f2b..d69293e3dcf9 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -39,12 +39,15 @@ ```py >>> import torch >>> from diffusers import TextToVideoMSPipeline + >>> from diffusers.utils import export_to_video >>> pipe = TextToVideoMSPipeline.from_pretrained("diffusers/ms-text-to-video-1.7b", torch_dtype=torch.float16) >>> pipe = pipe.to("cuda") >>> prompt = "Spiderman is surfing" >>> video_frames = pipe(prompt).frames + >>> video_path = export_to_video(video_frames) + >>> video_path ``` """ diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 196b3b0279d0..a43744db9bb8 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -57,6 +57,7 @@ is_librosa_available, is_omegaconf_available, is_onnx_available, + is_opencv_available, is_safetensors_available, is_scipy_available, is_tensorboard_available, @@ -92,6 +93,8 @@ torch_device, ) +if is_opencv_available(): + from .testing_utils import export_to_video logger = get_logger(__name__) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index b3c6d1824369..337828fd3e07 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -169,6 +169,12 @@ if _onnx_available: logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") +_opencv_available = importlib.util.find_spec("cv2") is not None +try: + _opencv_version = importlib_metadata.version("cv2") + logger.debug(f"Successfully imported cv2 version {_opencv_version}") +except importlib_metadata.PackageNotFoundError: + _opencv_available = False _scipy_available = importlib.util.find_spec("scipy") is not None try: @@ -272,6 +278,10 @@ def is_onnx_available(): return _onnx_available +def is_opencv_available(): + return _opencv_available + + def is_scipy_available(): return _scipy_available @@ -332,6 +342,12 @@ def is_compel_available(): install onnxruntime` """ +# docstyle-ignore +OPENCV_IMPORT_ERROR = """ +{0} requires the OpenCV library but it was not found in your environment. You can install it with pip: `pip +install opencv-python` +""" + # docstyle-ignore SCIPY_IMPORT_ERROR = """ {0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install @@ -391,6 +407,7 @@ def is_compel_available(): ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), + ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)), ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index cea2869b3193..65ea371b85f6 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -3,12 +3,13 @@ import os import random import re +import tempfile import unittest import urllib.parse from distutils.util import strtobool from io import BytesIO, StringIO from pathlib import Path -from typing import Optional, Union +from typing import List, Optional, Union import numpy as np import PIL.Image @@ -16,7 +17,13 @@ import requests from packaging import version -from .import_utils import is_compel_available, is_flax_available, is_onnx_available, is_torch_available +from .import_utils import ( + is_compel_available, + is_flax_available, + is_onnx_available, + is_opencv_available, + is_torch_available, +) from .logging import get_logger @@ -253,6 +260,21 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: return image +def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str: + if is_opencv_available(): + import cv2 + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + h, w, c = video_frames[0].shape + video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h)) + for i in range(len(video_frames)): + img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) + video_writer.write(img) + return output_video_path + + def load_hf_numpy(path) -> np.ndarray: if not path.startswith("http://") or path.startswith("https://"): path = os.path.join( From 5945729fb16088861b8eee3e8a6aff03ddf06fa6 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 22 Mar 2023 10:54:54 +0530 Subject: [PATCH 13/43] fix: opencv availability check. --- src/diffusers/utils/import_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 337828fd3e07..3c09cb24f965 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -169,9 +169,11 @@ if _onnx_available: logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") -_opencv_available = importlib.util.find_spec("cv2") is not None +# (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed. +# _opencv_available = importlib.util.find_spec("opencv-python") is not None try: - _opencv_version = importlib_metadata.version("cv2") + _opencv_version = importlib_metadata.version("opencv-python") + _opencv_available = True logger.debug(f"Successfully imported cv2 version {_opencv_version}") except importlib_metadata.PackageNotFoundError: _opencv_available = False From 5251c3a4e11aca83d2b3ef10eeb86daaf545807a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 22 Mar 2023 11:22:25 +0530 Subject: [PATCH 14/43] run make fix-copies. --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index c731a1f1ddf3..700a3080fa11 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -122,6 +122,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class UNet3DConditionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class VQModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 1b0f812ad16c..c36b5efd9ab3 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -347,6 +347,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class TextToVideoMSPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class UnCLIPImageVariationPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From cf8ac80a73e0d1a71a54b262d3bd118bfb1b81da Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 22 Mar 2023 11:29:57 +0530 Subject: [PATCH 15/43] add: docs for the model components. --- docs/source/en/api/models.mdx | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/source/en/api/models.mdx b/docs/source/en/api/models.mdx index dc425e98628c..c4820deb1e74 100644 --- a/docs/source/en/api/models.mdx +++ b/docs/source/en/api/models.mdx @@ -37,6 +37,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## UNet2DConditionModel [[autodoc]] UNet2DConditionModel +## UNet3DConditionOutput +[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput + +## UNet3DConditionModel +[[autodoc]] UNet3DConditionModel + ## DecoderOutput [[autodoc]] models.vae.DecoderOutput @@ -58,6 +64,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## Transformer2DModelOutput [[autodoc]] models.transformer_2d.Transformer2DModelOutput +## TransformerTempModel +[[autodoc]] TransformerTempModel + +## Transformer2DModelOutput +[[autodoc]] models.transformer_temp.TransformerTempModelOutput + ## PriorTransformer [[autodoc]] models.prior_transformer.PriorTransformer From 7a807642ce773a32f8e6b849b2fb5c0222025ce8 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 22 Mar 2023 11:54:11 +0530 Subject: [PATCH 16/43] add: standalone pipeline doc. --- docs/source/en/_toctree.yml | 2 + .../source/en/api/pipelines/text_to_video.mdx | 102 ++++++++++++++++++ .../pipeline_text_to_video_synth.py | 2 +- 3 files changed, 105 insertions(+), 1 deletion(-) create mode 100644 docs/source/en/api/pipelines/text_to_video.mdx diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 09012a5c693d..c3ce62ac05df 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -190,6 +190,8 @@ title: Stable unCLIP - local: api/pipelines/stochastic_karras_ve title: Stochastic Karras VE + - local: api/pipelines/text_to_video + title: Text-to-Video - local: api/pipelines/unclip title: UnCLIP - local: api/pipelines/latent_diffusion_uncond diff --git a/docs/source/en/api/pipelines/text_to_video.mdx b/docs/source/en/api/pipelines/text_to_video.mdx new file mode 100644 index 000000000000..fd77f346b46e --- /dev/null +++ b/docs/source/en/api/pipelines/text_to_video.mdx @@ -0,0 +1,102 @@ + + +# Text-to-video synthesis + +Text-to-video synthesis from [ModelScope](https://modelscope.cn/) can be considered the same as Stable Diffusion structure-wise but it is extended to videos instead of static images. More specifically, this system allows us to generate videos from a natural language text prompt. + +From the [model summary](https://huggingface.co/damo-vilab/modelscope-damo-text-to-video-synthesis): + +*This model is based on a multi-stage text-to-video generation diffusion model, which inputs a description text and returns a video that matches the text description. Only English input is supported.* + +Resources: + +* [Website](https://modelscope.cn/models/damo/text-to-video-synthesis/summary) +* [GitHub repository](https://github.com/modelscope/modelscope/) +* [Spaces] (TODO) + +## Available Pipelines: + +| Pipeline | Tasks | Demo +|---|---|:---:| +| [TextToVideoMSPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py) | *Text-to-Video Generation* | [Spaces] (TODO) + +## Usage example + +Let's start by generating a short video: + +```python +import torch +from diffusers import TextToVideoMSPipeline +from diffusers.utils import export_to_video + +pipe = TextToVideoMSPipeline.from_pretrained("diffusers/ms-text-to-video-1.7b", torch_dtype=torch.float16) +pipe = pipe.to("cuda") + +prompt = "Spiderman is surfing" +video_frames = pipe(prompt).frames +video_path = export_to_video(video_frames) +video_path +``` + +Diffusers supports different optimization techniques to for improving the latency +and memory footprint of a pipeline. Since videos are often more memory-heavy than images, +for this pipeline, we can enable CPU offloading and VAE slicing to keep the memory-footprint at bay. + +Let's generate a video of 8 seconds with CPU offloading and VAE slicing: + +```python +import torch +from diffusers import TextToVideoMSPipeline +from diffusers.utils import export_to_video + +pipe = TextToVideoMSPipeline.from_pretrained("diffusers/ms-text-to-video-1.7b", torch_dtype=torch.float16) +pipe = pipe.to("cuda") + +# memory optimization +pipe.enable_model_cpu_offload() +pipe.enable_vae_slicing() + +prompt = "Darth Vader surfing a wave" +video_frames = pipe(prompt, num_frames=64, num_inference_steps=25).frames +video_path = export_to_video(video_frames) +video_path +``` + +Together with PyTorch 2.0, "fp16" as the precision and the above techniques, it just takes 7 GBs of GPU memory. + +We can also use a different scheduler easily: + +```python +import torch +from diffusers import TextToVideoMSPipeline, DPMSolverMultistepScheduler +from diffusers.utils import export_to_video + +pipe = TextToVideoMSPipeline.from_pretrained("diffusers/ms-text-to-video-1.7b", torch_dtype=torch.float16) +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to("cuda") + +prompt = "Spiderman is surfing" +video_frames = pipe(prompt).frames +video_path = export_to_video(video_frames) +video_path +``` + +## Available checkpoints + +* [diffusers/ms-text-to-video-sd](https://huggingface.co/diffusers/ms-text-to-video-sd/) +* [diffusers/ms-text-to-video-1.7b](https://huggingface.co/diffusers/ms-text-to-video-1.7b) + +## TextToVideoMSPipeline +[[autodoc]] TextToVideoMSPipeline + - all + - __call__ \ No newline at end of file diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index d69293e3dcf9..1839204b5486 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -71,7 +71,7 @@ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - class TextToVideoMSPipeline(DiffusionPipeline): r""" - Pipeline for text-to-image generation using Stable Diffusion. + Pipeline for text-to-video generation. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) From f5b3fe4a7200d576383e65c36f992c56d7c64413 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 22 Mar 2023 12:01:31 +0530 Subject: [PATCH 17/43] edit docstring of the pipeline. --- .../text_to_video_synthesis/pipeline_text_to_video_synth.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 1839204b5486..ee91b692428e 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -80,9 +80,7 @@ class TextToVideoMSPipeline(DiffusionPipeline): vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + Frozen text-encoder. Same as Stable Diffusion 2. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). From fb916ba0562c81120874b6b1eda754323e49dfab Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 22 Mar 2023 12:09:48 +0530 Subject: [PATCH 18/43] add: right path to TransformerTempModel --- docs/source/en/api/models.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/models.mdx b/docs/source/en/api/models.mdx index c4820deb1e74..2a24e8cf95f1 100644 --- a/docs/source/en/api/models.mdx +++ b/docs/source/en/api/models.mdx @@ -65,7 +65,7 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module [[autodoc]] models.transformer_2d.Transformer2DModelOutput ## TransformerTempModel -[[autodoc]] TransformerTempModel +[[autodoc]] models.transformer_temp.TransformerTempModel ## Transformer2DModelOutput [[autodoc]] models.transformer_temp.TransformerTempModelOutput From 880cfce82b1cdf6f52a95049f1b64d637be62a24 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 22 Mar 2023 15:03:58 +0530 Subject: [PATCH 19/43] add: first set of tests. --- .../pipeline_text_to_video_synth.py | 6 +- tests/pipelines/text_to_video/__init__.py | 0 .../text_to_video/test_text_to_video.py | 268 ++++++++++++++++++ 3 files changed, 269 insertions(+), 5 deletions(-) create mode 100644 tests/pipelines/text_to_video/__init__.py create mode 100644 tests/pipelines/text_to_video/test_text_to_video.py diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index ee91b692428e..0cdb83eadc49 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -219,7 +219,7 @@ def _encode_prompt( Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -472,7 +472,6 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, @@ -521,9 +520,6 @@ def __call__( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoMSPipelineOutput`] instead of a plain tuple. diff --git a/tests/pipelines/text_to_video/__init__.py b/tests/pipelines/text_to_video/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/text_to_video/test_text_to_video.py b/tests/pipelines/text_to_video/test_text_to_video.py new file mode 100644 index 000000000000..e6ee51294c26 --- /dev/null +++ b/tests/pipelines/text_to_video/test_text_to_video.py @@ -0,0 +1,268 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + TextToVideoMSPipeline, + UNet3DConditionModel, +) +from diffusers.utils import torch_device + +from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class TextToVideoMSPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = TextToVideoMSPipeline + params = TEXT_TO_IMAGE_PARAMS + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + # No `output_type`. + required_optional_params = frozenset( + [ + "num_inference_steps", + "num_images_per_prompt", + "generator", + "latents", + "return_dict", + "callback", + "callback_steps", + ] + ) + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet3DConditionModel( + block_out_channels=(32, 64, 64, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D"), + up_block_types=("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + cross_attention_dim=32, + attention_head_dim=4, + use_linear_projection=True, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=512, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + } + return inputs + + def test_text_to_video_ddim(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = TextToVideoMSPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + frames = sd_pipe(**inputs).frames + image_slice = frames[0][-3:, -3:, -1] + + slice = [round(x, 4) for x in image_slice.flatten().tolist()] + print(",".join([str(x) for x in slice])) + + assert frames[0].shape == (64, 64, 3) + expected_slice = np.array([166, 184, 167, 118, 102, 123, 108, 93, 114]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + # (todo): sayakpaul + @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.") + def test_inference_batch_consistent(self): + pass + + # (todo): sayakpaul + @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.") + def test_inference_batch_single_identical(self): + pass + + @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.") + def test_num_images_per_prompt(self): + pass + + # Overriding since the output type for this pipeline differs from that of + # text-to-image pipelines. + def test_attention_slicing_forward_pass(self): + self._test_attention_slicing_forward_pass() + + def _test_attention_slicing_forward_pass(self, expected_max_diff=1e-3): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = pipe(**self.get_dummy_inputs(torch_device)) + + inputs = self.get_dummy_inputs(torch_device) + output_without_slicing = pipe(**inputs).frames[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(torch_device) + output_with_slicing = pipe(**inputs).frames[0] + + max_diff = np.abs((output_with_slicing / 255.0) - (output_without_slicing / 255.0)).max() + self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results") + + avg_diff = np.abs(output_without_slicing - output_without_slicing).mean() + self.assertLess(avg_diff, 10, f"Error image deviates {avg_diff} pixels on average") + + # Overriding since the output type for this pipeline differs from that of + # text-to-image pipelines. + def test_dict_tuple_outputs_equivalent(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = pipe(**self.get_dummy_inputs(torch_device)) + + output = pipe(**self.get_dummy_inputs(torch_device)).frames[0] + output_tuple = pipe(**self.get_dummy_inputs(torch_device), return_dict=False)[0][0] + + max_diff = np.abs(output / 255.0 - output_tuple / 255.0).max() + self.assertLess(max_diff, 1e-4) + + def test_save_load_local(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = pipe(**self.get_dummy_inputs(torch_device)) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs).frames[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs).frames[0] + + max_diff = np.abs((output / 255.0) - (output_loaded / 255.0)).max() + self.assertLess(max_diff, 1e-4) + + def test_save_load_optional_components(self): + if not hasattr(self.pipeline_class, "_optional_components"): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = pipe(**self.get_dummy_inputs(torch_device)) + + # set all optional components to None + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs).frames[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs).frames[0] + + max_diff = np.abs((output / 255.0) - (output_loaded / 255.0)).max() + self.assertLess(max_diff, 1e-4) From 6f0f5e351aceeb4341cc57c86f537a91f2722bb4 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 22 Mar 2023 15:38:38 +0530 Subject: [PATCH 20/43] complete fast tests for text to video. --- .../text_to_video/test_text_to_video.py | 58 +++++++++++++++++-- 1 file changed, 52 insertions(+), 6 deletions(-) diff --git a/tests/pipelines/text_to_video/test_text_to_video.py b/tests/pipelines/text_to_video/test_text_to_video.py index e6ee51294c26..269293a341c5 100644 --- a/tests/pipelines/text_to_video/test_text_to_video.py +++ b/tests/pipelines/text_to_video/test_text_to_video.py @@ -23,10 +23,11 @@ from diffusers import ( AutoencoderKL, DDIMScheduler, + DPMSolverMultistepScheduler, TextToVideoMSPipeline, UNet3DConditionModel, ) -from diffusers.utils import torch_device +from diffusers.utils import skip_mps, torch_device from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ...test_pipelines_common import PipelineTesterMixin @@ -122,7 +123,7 @@ def get_dummy_inputs(self, device, seed=0): } return inputs - def test_text_to_video_ddim(self): + def test_text_to_video_default_case(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() sd_pipe = TextToVideoMSPipeline(**components) @@ -133,14 +134,47 @@ def test_text_to_video_ddim(self): frames = sd_pipe(**inputs).frames image_slice = frames[0][-3:, -3:, -1] - slice = [round(x, 4) for x in image_slice.flatten().tolist()] - print(",".join([str(x) for x in slice])) - assert frames[0].shape == (64, 64, 3) expected_slice = np.array([166, 184, 167, 118, 102, 123, 108, 93, 114]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_pix2pix_negative_prompt(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = TextToVideoMSPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + negative_prompt = "french fries" + frames = sd_pipe(**inputs, negative_prompt=negative_prompt).frames + image_slice = frames[0][-3:, -3:, -1] + + assert frames[0].shape == (64, 64, 3) + expected_slice = np.array([166, 181, 167, 119, 99, 124, 110, 94, 114]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_stable_diffusion_pix2pix_dpm_multistep(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + components["scheduler"] = DPMSolverMultistepScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) + sd_pipe = TextToVideoMSPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + frames = sd_pipe(**inputs).frames + image_slice = frames[0][-3:, -3:, -1] + + assert frames[0].shape == (64, 64, 3) + expected_slice = np.array([170, 190, 180, 140, 121, 136, 121, 97, 122]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + # (todo): sayakpaul @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.") def test_inference_batch_consistent(self): @@ -157,10 +191,11 @@ def test_num_images_per_prompt(self): # Overriding since the output type for this pipeline differs from that of # text-to-image pipelines. + @skip_mps def test_attention_slicing_forward_pass(self): self._test_attention_slicing_forward_pass() - def _test_attention_slicing_forward_pass(self, expected_max_diff=1e-3): + def _test_attention_slicing_forward_pass(self, expected_max_diff=4e-3): if not self.test_attention_slicing: return @@ -188,6 +223,7 @@ def _test_attention_slicing_forward_pass(self, expected_max_diff=1e-3): # Overriding since the output type for this pipeline differs from that of # text-to-image pipelines. + @skip_mps def test_dict_tuple_outputs_equivalent(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -204,6 +240,13 @@ def test_dict_tuple_outputs_equivalent(self): max_diff = np.abs(output / 255.0 - output_tuple / 255.0).max() self.assertLess(max_diff, 1e-4) + @skip_mps + def test_progress_bar(self): + return super().test_progress_bar() + + # Overriding since the output type for this pipeline differs from that of + # text-to-image pipelines. + @skip_mps def test_save_load_local(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -229,6 +272,9 @@ def test_save_load_local(self): max_diff = np.abs((output / 255.0) - (output_loaded / 255.0)).max() self.assertLess(max_diff, 1e-4) + # Overriding since the output type for this pipeline differs from that of + # text-to-image pipelines. + @skip_mps def test_save_load_optional_components(self): if not hasattr(self.pipeline_class, "_optional_components"): return From d58cb7fa503cdf11e78769560d2148a4ab83ee56 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 11:35:27 +0000 Subject: [PATCH 21/43] fix bug --- src/diffusers/models/unet_3d_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 8e2b8d2f2e5f..3ad101cb31cd 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -185,8 +185,8 @@ def __init__( for _ in range(num_layers): attentions.append( Transformer2DModel( - attn_num_head_channels, in_channels // attn_num_head_channels, + attn_num_head_channels, in_channels=in_channels, num_layers=1, cross_attention_dim=cross_attention_dim, @@ -197,8 +197,8 @@ def __init__( ) temp_attentions.append( TransformerTempModel( - attn_num_head_channels, in_channels // attn_num_head_channels, + attn_num_head_channels, in_channels=in_channels, num_layers=1, cross_attention_dim=cross_attention_dim, From 387181c398a6d1dbed9702a6e2b37a66c2d8db6d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 11:54:07 +0000 Subject: [PATCH 22/43] up --- src/diffusers/models/unet_3d_condition.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index e1dce2a07db6..ae32e11f9b2b 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -1,4 +1,5 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# Copyright 2023 The ModelScope Team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 0a9c4951d566d6513fc69fcf81bc2c99935e2cb0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 22 Mar 2023 18:32:14 +0530 Subject: [PATCH 23/43] three fast tests failing. --- .../source/en/api/pipelines/text_to_video.mdx | 21 + tests/models/test_models_unet_3d_condition.py | 394 ++++++++++++++++++ 2 files changed, 415 insertions(+) create mode 100644 tests/models/test_models_unet_3d_condition.py diff --git a/docs/source/en/api/pipelines/text_to_video.mdx b/docs/source/en/api/pipelines/text_to_video.mdx index fd77f346b46e..a54fda575c21 100644 --- a/docs/source/en/api/pipelines/text_to_video.mdx +++ b/docs/source/en/api/pipelines/text_to_video.mdx @@ -91,6 +91,27 @@ video_path = export_to_video(video_frames) video_path ``` +Here are sample outputs: + + + + + + +
+ An astronaut riding a horse. +
+ An astronaut riding a horse. +
+ Darth vader surfing in waves. +
+ Darth vader surfing in waves. +
+ ## Available checkpoints * [diffusers/ms-text-to-video-sd](https://huggingface.co/diffusers/ms-text-to-video-sd/) diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py new file mode 100644 index 000000000000..9acfc456fbf5 --- /dev/null +++ b/tests/models/test_models_unet_3d_condition.py @@ -0,0 +1,394 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers.models import ModelMixin, UNet3DConditionModel +from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.utils import ( + floats_tensor, + logging, + torch_all_close, + torch_device, +) +from diffusers.utils.import_utils import is_xformers_available + +from ..test_modeling_common import ModelTesterMixin + + +logger = logging.get_logger(__name__) +torch.backends.cuda.matmul.allow_tf32 = False + + +def create_lora_layers(model): + lora_attn_procs = {} + for name in model.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = model.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + lora_attn_procs[name] = lora_attn_procs[name].to(model.device) + + # add 1 to weights to mock trained weights + with torch.no_grad(): + lora_attn_procs[name].to_q_lora.up.weight += 1 + lora_attn_procs[name].to_k_lora.up.weight += 1 + lora_attn_procs[name].to_v_lora.up.weight += 1 + lora_attn_procs[name].to_out_lora.up.weight += 1 + + return lora_attn_procs + + +class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet3DConditionModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 4 + num_frames = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device) + + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + + @property + def input_shape(self): + return (4, 4, 32, 32) + + @property + def output_shape(self): + return (4, 4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64, 64, 64), + "down_block_types": ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + "up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + "cross_attention_dim": 32, + "attention_head_dim": 4, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 2, + "sample_size": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_enable_works(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + + model.enable_xformers_memory_efficient_attention() + + assert ( + model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ + == "XFormersAttnProcessor" + ), "xformers is not enabled" + + @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS") + def test_gradient_checkpointing(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + + assert not model.is_gradient_checkpointing and model.training + + out = model(**inputs_dict).sample + # run the backwards pass on the model. For backwards pass, for simplicity purpose, + # we won't calculate the loss and rather backprop on out.sum() + model.zero_grad() + + labels = torch.randn_like(out) + loss = (out - labels).mean() + loss.backward() + + # re-instantiate the model now enabling gradient checkpointing + model_2 = self.model_class(**init_dict) + # clone model + model_2.load_state_dict(model.state_dict()) + model_2.to(torch_device) + model_2.enable_gradient_checkpointing() + + assert model_2.is_gradient_checkpointing and model_2.training + + out_2 = model_2(**inputs_dict).sample + # run the backwards pass on the model. For backwards pass, for simplicity purpose, + # we won't calculate the loss and rather backprop on out.sum() + model_2.zero_grad() + loss_2 = (out_2 - labels).mean() + loss_2.backward() + + # compare the output and parameters gradients + self.assertTrue((loss - loss_2).abs() < 1e-4) + named_params = dict(model.named_parameters()) + named_params_2 = dict(model_2.named_parameters()) + for name, param in named_params.items(): + self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-4)) + + # Overriding because `block_out_channels` needs to be different for this model. + def test_forward_with_norm_groups(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["norm_num_groups"] = 32 + init_dict["block_out_channels"] = (32, 64, 64, 64) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + # Overriding since the UNet3D outputs a different structure. + def test_determinism(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + # Warmup pass when using mps (see #372) + if torch_device == "mps" and isinstance(model, ModelMixin): + model(**self.dummy_input) + + first = model(**inputs_dict) + if isinstance(first, dict): + first = first.sample + + second = model(**inputs_dict) + if isinstance(second, dict): + second = second.sample + + out_1 = first.cpu().numpy() + out_2 = second.cpu().numpy() + out_1 = out_1[~np.isnan(out_1)] + out_2 = out_2[~np.isnan(out_2)] + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5) + + def test_model_with_attention_head_dim_tuple(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16, 16, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_model_with_use_linear_projection(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["use_linear_projection"] = True + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_model_attention_slicing(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + model.set_attention_slice("auto") + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + model.set_attention_slice("max") + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + model.set_attention_slice(2) + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + def test_model_slicable_head_dim(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16, 16, 16) + + model = self.model_class(**init_dict) + + def check_slicable_dim_attr(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + assert isinstance(module.sliceable_head_dim, int) + + for child in module.children(): + check_slicable_dim_attr(child) + + # retrieve number of attention layers + for module in model.children(): + check_slicable_dim_attr(module) + + def test_special_attn_proc(self): + class AttnEasyProc(torch.nn.Module): + def __init__(self, num): + super().__init__() + self.weight = torch.nn.Parameter(torch.tensor(num)) + self.is_run = False + self.number = 0 + self.counter = 0 + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states += self.weight + + self.is_run = True + self.counter += 1 + self.number = number + + return hidden_states + + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16, 16, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + processor = AttnEasyProc(5.0) + + model.set_attn_processor(processor) + model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample + + assert processor.counter == 12 + assert processor.is_run + assert processor.number == 123 + + # (`attn_processors`) needs to be implemented in this model for this test. + # def test_lora_processors(self): + + # (`attn_processors`) needs to be implemented in this model for this test. + # def test_lora_save_load(self): + + # (`attn_processors`) needs to be implemented for this test in the model. + # def test_lora_save_load_safetensors(self): + + # (`attn_processors`) needs to be implemented for this test in the model. + # def test_lora_save_safetensors_load_torch(self): + + # (`attn_processors`) needs to be implemented for this test. + # def test_lora_save_torch_force_load_safetensors_error(self): + + # (`attn_processors`) needs to be added for this test. + # def test_lora_on_off(self): + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_lora_xformers_on_off(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 4 + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + lora_attn_procs = create_lora_layers(model) + model.set_attn_processor(lora_attn_procs) + + # default + with torch.no_grad(): + sample = model(**inputs_dict).sample + + model.enable_xformers_memory_efficient_attention() + on_sample = model(**inputs_dict).sample + + model.disable_xformers_memory_efficient_attention() + off_sample = model(**inputs_dict).sample + + assert (sample - on_sample).abs().max() < 1e-4 + assert (sample - off_sample).abs().max() < 1e-4 From 50e895042b48960f5483bd9fc2a4842095f2e636 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 22 Mar 2023 18:32:59 +0530 Subject: [PATCH 24/43] add: note on slow tests --- tests/models/test_models_unet_3d_condition.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index 9acfc456fbf5..52a4da5f65fa 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -392,3 +392,6 @@ def test_lora_xformers_on_off(self): assert (sample - on_sample).abs().max() < 1e-4 assert (sample - off_sample).abs().max() < 1e-4 + + +# (todo: sayakpaul) implement SLOW tests. \ No newline at end of file From 479967034e2e8592120cc4c925e5a4a61b9e93c7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 13:04:15 +0000 Subject: [PATCH 25/43] make work with all schedulers --- .../pipeline_text_to_video_synth.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 0cdb83eadc49..7f08d25defc3 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -617,9 +617,17 @@ def __call__( noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # reshape latents + bsz, channel, frames, width, height = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # reshape latents back + latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4) + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() From b131d48b3ec87c849a4ab24f7a25365ae7617037 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 22 Mar 2023 18:36:51 +0530 Subject: [PATCH 26/43] apply styling. --- tests/models/test_models_unet_3d_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index 52a4da5f65fa..794c1c5a0cab 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -394,4 +394,4 @@ def test_lora_xformers_on_off(self): assert (sample - off_sample).abs().max() < 1e-4 -# (todo: sayakpaul) implement SLOW tests. \ No newline at end of file +# (todo: sayakpaul) implement SLOW tests. From bd50840c128ac408c9b99d5ef2dfc6520ba7c9bf Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 13:45:36 +0000 Subject: [PATCH 27/43] add slow tests --- .../pipeline_text_to_video_synth.py | 7 +++- .../text_to_video/test_text_to_video.py | 38 ++++++++++++++++++- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 7f08d25defc3..301c28d36b56 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -472,6 +472,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "np", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, @@ -635,7 +636,11 @@ def __call__( callback(i, t, latents) video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor) + + if output_type == "pt": + video = video_tensor + else: + video = tensor2vid(video_tensor) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/tests/pipelines/text_to_video/test_text_to_video.py b/tests/pipelines/text_to_video/test_text_to_video.py index 269293a341c5..62686567e5c5 100644 --- a/tests/pipelines/text_to_video/test_text_to_video.py +++ b/tests/pipelines/text_to_video/test_text_to_video.py @@ -27,7 +27,7 @@ TextToVideoMSPipeline, UNet3DConditionModel, ) -from diffusers.utils import skip_mps, torch_device +from diffusers.utils import load_numpy, skip_mps, slow, torch_device from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ...test_pipelines_common import PipelineTesterMixin @@ -312,3 +312,39 @@ def test_save_load_optional_components(self): max_diff = np.abs((output / 255.0) - (output_loaded / 255.0)).max() self.assertLess(max_diff, 1e-4) + + +@slow +class TextToVideoMSPipelineSlowTests(unittest.TestCase): + def test_full_model(self): + expected_video = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video.npy" + ) + + pipe = TextToVideoMSPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b") + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe = pipe.to("cuda") + + prompt = "Spiderman is surfing" + generator = torch.Generator(device="cpu").manual_seed(0) + + video_frames = pipe(prompt, generator=generator, num_inference_steps=25, output_type="pt").frames + video = video_frames.cpu().numpy() + + assert np.abs(expected_video - video).mean() < 5e-2 + + def test_two_step_model(self): + expected_video = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video_2step.npy" + ) + + pipe = TextToVideoMSPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b") + pipe = pipe.to("cuda") + + prompt = "Spiderman is surfing" + generator = torch.Generator(device="cpu").manual_seed(0) + + video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames + video = video_frames.cpu().numpy() + + assert np.abs(expected_video - video).mean() < 5e-2 From 4a5267a0df623cc3f0e09b69c919b05832b24894 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 13:52:21 +0000 Subject: [PATCH 28/43] change file name --- docs/source/en/api/models.mdx | 4 ++-- .../models/{transformer_temp.py => transformer_temporal.py} | 0 src/diffusers/models/unet_3d_blocks.py | 2 +- src/diffusers/models/unet_3d_condition.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) rename src/diffusers/models/{transformer_temp.py => transformer_temporal.py} (100%) diff --git a/docs/source/en/api/models.mdx b/docs/source/en/api/models.mdx index 2a24e8cf95f1..a0fb3d61a65a 100644 --- a/docs/source/en/api/models.mdx +++ b/docs/source/en/api/models.mdx @@ -65,10 +65,10 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module [[autodoc]] models.transformer_2d.Transformer2DModelOutput ## TransformerTempModel -[[autodoc]] models.transformer_temp.TransformerTempModel +[[autodoc]] models.transformer_temporal.TransformerTempModel ## Transformer2DModelOutput -[[autodoc]] models.transformer_temp.TransformerTempModelOutput +[[autodoc]] models.transformer_temporal.TransformerTempModelOutput ## PriorTransformer [[autodoc]] models.prior_transformer.PriorTransformer diff --git a/src/diffusers/models/transformer_temp.py b/src/diffusers/models/transformer_temporal.py similarity index 100% rename from src/diffusers/models/transformer_temp.py rename to src/diffusers/models/transformer_temporal.py diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 3ad101cb31cd..e83557bb7284 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -17,7 +17,7 @@ from .resnet import Downsample2D, ResnetBlock2D, Upsample2D from .transformer_2d import Transformer2DModel -from .transformer_temp import TransformerTempModel +from .transformer_temporal import TransformerTempModel def get_down_block( diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index ae32e11f9b2b..ad89e8fb3ea4 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -23,7 +23,7 @@ from ..utils import BaseOutput, logging from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin -from .transformer_temp import TransformerTempModel +from .transformer_temporal import TransformerTempModel from .unet_3d_blocks import ( CrossAttnDownBlock3D, CrossAttnUpBlock3D, From 7b3c48d994fc151cff8deb4c669414ff8304badb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 13:59:00 +0000 Subject: [PATCH 29/43] update --- src/diffusers/models/transformer_temporal.py | 63 ++------------------ src/diffusers/models/unet_3d_blocks.py | 7 --- src/diffusers/models/unet_3d_condition.py | 2 - 3 files changed, 5 insertions(+), 67 deletions(-) diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index ab813a178723..93e6ea2527f4 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -18,7 +18,7 @@ from torch import nn from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, deprecate +from ..utils import BaseOutput from .attention import BasicTransformerBlock from .modeling_utils import ModelMixin @@ -27,9 +27,8 @@ class TransformerTempModelOutput(BaseOutput): """ Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`TransformerTempModel`] is discrete): - Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions - for the unnoised latent pixels. + sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`) + Hidden states conditioned on `encoder_hidden_states` input. """ sample: torch.FloatTensor @@ -61,16 +60,11 @@ class TransformerTempModel(ModelMixin, ConfigMixin): sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. Note that this is fixed at training time as it is used for learning a number of position embeddings. See `ImagePositionalEmbeddings`. - num_vector_embeds (`int`, *optional*): - Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. - Includes the class for the masked latent pixel. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. - The number of diffusion steps used during training. Note that this is fixed at training time as it is used - to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for - up to but not more than steps than `num_embeds_ada_norm`. attention_bias (`bool`, *optional*): Configure if the TransformerBlocks' attention should contain a bias parameter. + double_self_attention (`bool`, *optional*): + Configure if each TransformerBlock should contain two self-attention layers """ @register_to_config @@ -86,57 +80,15 @@ def __init__( cross_attention_dim: Optional[int] = None, attention_bias: bool = False, sample_size: Optional[int] = None, - num_vector_embeds: Optional[int] = None, - patch_size: Optional[int] = None, activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - norm_type: str = "layer_norm", norm_elementwise_affine: bool = True, double_self_attention: bool = True, ): super().__init__() - self.use_linear_projection = use_linear_projection self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim - # 1. TransformerTempModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` - # Define whether input is continuous or discrete depending on configuration - self.is_input_continuous = (in_channels is not None) and (patch_size is None) - self.is_input_vectorized = num_vector_embeds is not None - self.is_input_patches = in_channels is not None and patch_size is not None - - if norm_type == "layer_norm" and num_embeds_ada_norm is not None: - deprecation_message = ( - f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" - " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." - " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" - " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" - " would be very nice if you could open a Pull request for the `transformer/config.json` file" - ) - deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) - norm_type = "ada_norm" - - if self.is_input_continuous and self.is_input_vectorized: - raise ValueError( - f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" - " sure that either `in_channels` or `num_vector_embeds` is None." - ) - elif self.is_input_vectorized and self.is_input_patches: - raise ValueError( - f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" - " sure that either `num_vector_embeds` or `num_patches` is None." - ) - elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: - raise ValueError( - f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" - f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." - ) - - # 2. Define input layers self.in_channels = in_channels self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) @@ -152,19 +104,14 @@ def __init__( dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, - only_cross_attention=only_cross_attention, double_self_attention=double_self_attention, - upcast_attention=upcast_attention, - norm_type=norm_type, norm_elementwise_affine=norm_elementwise_affine, ) for d in range(num_layers) ] ) - # 4. Define output layers self.proj_out = nn.Linear(inner_dim, in_channels) def forward( diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index e83557bb7284..08750b2c27e8 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -203,8 +203,6 @@ def __init__( num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, ) ) resnets.append( @@ -336,8 +334,6 @@ def __init__( num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, ) ) self.resnets = nn.ModuleList(resnets) @@ -550,9 +546,6 @@ def __init__( num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, ) ) self.resnets = nn.ModuleList(resnets) diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index ad89e8fb3ea4..e174eda93f17 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -211,8 +211,6 @@ def __init__( attention_head_dim=attention_head_dim, in_channels=block_out_channels[0], num_layers=1, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, ) # class embedding From 48d05a4e72bf137058595eab3b36c36c3e7e9094 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 14:43:14 +0000 Subject: [PATCH 30/43] more correction --- examples/community/stable_diffusion_controlnet_img2img.py | 2 +- examples/community/stable_diffusion_controlnet_inpaint.py | 2 +- .../stable_diffusion_controlnet_inpaint_img2img.py | 2 +- .../pipelines/alt_diffusion/pipeline_alt_diffusion.py | 2 +- .../alt_diffusion/pipeline_alt_diffusion_img2img.py | 2 +- .../pipelines/stable_diffusion/pipeline_cycle_diffusion.py | 2 +- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 +- .../pipeline_stable_diffusion_controlnet.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 2 +- .../pipeline_stable_diffusion_inpaint_legacy.py | 2 +- .../pipeline_stable_diffusion_instruct_pix2pix.py | 2 +- .../pipeline_stable_diffusion_k_diffusion.py | 2 +- .../pipeline_stable_diffusion_pix2pix_zero.py | 2 +- .../text_to_video_synthesis/pipeline_text_to_video_synth.py | 6 +++--- 15 files changed, 17 insertions(+), 17 deletions(-) diff --git a/examples/community/stable_diffusion_controlnet_img2img.py b/examples/community/stable_diffusion_controlnet_img2img.py index 51533a92d84a..ec23564ae3cb 100644 --- a/examples/community/stable_diffusion_controlnet_img2img.py +++ b/examples/community/stable_diffusion_controlnet_img2img.py @@ -216,7 +216,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py index 02e71fb97ed1..b7c8a2a7a7f0 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint.py +++ b/examples/community/stable_diffusion_controlnet_inpaint.py @@ -314,7 +314,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py index a7afe26fa91c..f435a3274f45 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py +++ b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py @@ -314,7 +314,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index b94a2ec05649..1ae82beb54a4 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -234,7 +234,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 05138c86f246..b71217a4b3ec 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -244,7 +244,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index e977071b9c6c..76423867add1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -258,7 +258,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 5294fa4cfa06..81b2cfa9bc3e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -237,7 +237,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index fd82281005ad..aeb70b1b2234 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -274,7 +274,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 8b3a7944def1..835c88e19448 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -249,7 +249,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index b645ba667f77..cee7ace239db 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -293,7 +293,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index a770fb18aaa0..cb953a7803b2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -237,7 +237,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 953df11aa4f7..06ab580d492f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -432,7 +432,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index f3db54caa342..2d40390b41d1 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -158,7 +158,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 7de12bd291fb..9c928129d0b9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -394,7 +394,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 301c28d36b56..8a51adc002ff 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -41,8 +41,8 @@ >>> from diffusers import TextToVideoMSPipeline >>> from diffusers.utils import export_to_video - >>> pipe = TextToVideoMSPipeline.from_pretrained("diffusers/ms-text-to-video-1.7b", torch_dtype=torch.float16) - >>> pipe = pipe.to("cuda") + >>> pipe = TextToVideoMSPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16") + >>> pipe.enable_model_cpu_offload() >>> prompt = "Spiderman is surfing" >>> video_frames = pipe(prompt).frames @@ -172,7 +172,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") From fb060aba76d709b4196700832f776fc73447dbb1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 14:59:57 +0000 Subject: [PATCH 31/43] more fixes --- src/diffusers/models/resnet.py | 64 +++++++++++++ src/diffusers/models/unet_3d_blocks.py | 74 +++------------ src/diffusers/models/unet_3d_condition.py | 90 ++----------------- .../pipeline_text_to_video_synth.py | 6 +- 4 files changed, 85 insertions(+), 149 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index d159115d7ee3..cc4882aa0961 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -1,3 +1,18 @@ +# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# Copyright 2023 The ModelScope Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from functools import partial from typing import Optional @@ -764,3 +779,52 @@ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 return out.view(-1, channel, out_h, out_w) + + +class TemporalConvLayer(nn.Module): + def __init__(self, in_dim, out_dim=None, dropout=0.0): + super().__init__() + out_dim = out_dim or in_dim + self.in_dim = in_dim + self.out_dim = out_dim + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)) + ) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + self.conv3 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + self.conv4 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + + def forward(self, hidden_states, num_frames=1): + hidden_states = hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4) + + identity = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.conv3(hidden_states) + hidden_states = self.conv4(hidden_states) + + hidden_states = identity + hidden_states + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape((hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]) + return hidden_states diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 08750b2c27e8..b59b4df5584b 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -15,7 +15,7 @@ import torch from torch import nn -from .resnet import Downsample2D, ResnetBlock2D, Upsample2D +from .resnet import Downsample2D, ResnetBlock2D, Upsample2D, TemporalConvLayer from .transformer_2d import Transformer2DModel from .transformer_temporal import TransformerTempModel @@ -34,7 +34,7 @@ def get_down_block( cross_attention_dim=None, downsample_padding=None, dual_cross_attention=False, - use_linear_projection=False, + use_linear_projection=True, only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", @@ -90,7 +90,7 @@ def get_up_block( resnet_groups=None, cross_attention_dim=None, dual_cross_attention=False, - use_linear_projection=False, + use_linear_projection=True, only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", @@ -148,7 +148,7 @@ def __init__( output_scale_factor=1.0, cross_attention_dim=1280, dual_cross_attention=False, - use_linear_projection=False, + use_linear_projection=True, upcast_attention=False, ): super().__init__() @@ -173,7 +173,7 @@ def __init__( ) ] temp_convs = [ - TemporalConvBlock_v2( + TemporalConvLayer( in_channels, in_channels, dropout=0.1, @@ -220,7 +220,7 @@ def __init__( ) ) temp_convs.append( - TemporalConvBlock_v2( + TemporalConvLayer( in_channels, in_channels, dropout=0.1, @@ -307,7 +307,7 @@ def __init__( ) ) temp_convs.append( - TemporalConvBlock_v2( + TemporalConvLayer( out_channels, out_channels, dropout=0.1, @@ -427,7 +427,7 @@ def __init__( ) ) temp_convs.append( - TemporalConvBlock_v2( + TemporalConvLayer( out_channels, out_channels, dropout=0.1, @@ -519,7 +519,7 @@ def __init__( ) ) temp_convs.append( - TemporalConvBlock_v2( + TemporalConvLayer( out_channels, out_channels, dropout=0.1, @@ -636,7 +636,7 @@ def __init__( ) ) temp_convs.append( - TemporalConvBlock_v2( + TemporalConvLayer( out_channels, out_channels, dropout=0.1, @@ -668,57 +668,3 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si hidden_states = upsampler(hidden_states, upsample_size) return hidden_states - - -class TemporalConvBlock_v2(nn.Module): - def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset=False): - super(TemporalConvBlock_v2, self).__init__() - if out_dim is None: - out_dim = in_dim # int(1.5*in_dim) - self.in_dim = in_dim - self.out_dim = out_dim - self.use_image_dataset = use_image_dataset - - # conv layers - self.conv1 = nn.Sequential( - nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)) - ) - self.conv2 = nn.Sequential( - nn.GroupNorm(32, out_dim), - nn.SiLU(), - nn.Dropout(dropout), - nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), - ) - self.conv3 = nn.Sequential( - nn.GroupNorm(32, out_dim), - nn.SiLU(), - nn.Dropout(dropout), - nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), - ) - self.conv4 = nn.Sequential( - nn.GroupNorm(32, out_dim), - nn.SiLU(), - nn.Dropout(dropout), - nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), - ) - - # zero out the last layer params,so the conv block is identity - nn.init.zeros_(self.conv4[-1].weight) - nn.init.zeros_(self.conv4[-1].bias) - - def forward(self, x, num_frames=1): - x = x[None, :].reshape((-1, num_frames) + x.shape[1:]).permute(0, 2, 1, 3, 4) - - identity = x - x = self.conv1(x) - x = self.conv2(x) - x = self.conv3(x) - x = self.conv4(x) - - if self.use_image_dataset: - x = identity + 0.0 * x - else: - x = identity + x - - x = x.permute(0, 2, 1, 3, 4).reshape((x.shape[0] * x.shape[2], -1) + x.shape[3:]) - return x diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index e174eda93f17..40b91600edef 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging -from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .transformer_temporal import TransformerTempModel from .unet_3d_blocks import ( @@ -42,7 +42,7 @@ class UNet3DConditionOutput(BaseOutput): """ Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. """ @@ -62,20 +62,10 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): Height and width of input/output sample. in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. - center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `False`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): - The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the - mid block layer if `None`. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): The tuple of upsample blocks to use. - only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): - Whether to include self-attention in the basic transformer blocks, see - [`~models.attention.BasicTransformerBlock`]. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. @@ -87,24 +77,6 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. - resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to None): - The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, - `"timestep"`, `"identity"`, or `"projection"`. - num_class_embeds (`int`, *optional*, defaults to None): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing - class conditioning with `class_embed_type` equal to `None`. - time_embedding_type (`str`, *optional*, default to `positional`): - The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. - timestep_post_act (`str, *optional*, default to `None`): - The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. - time_cond_proj_dim (`int`, *optional*, default to `None`): - The dimension of `cond_proj` layer in timestep embedding. - conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. - conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. - projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when - using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. """ _supports_gradient_checkpointing = True @@ -115,36 +87,22 @@ def __init__( sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, - center_input_sample: bool = False, # remove - flip_sin_to_cos: bool = True, # remove - freq_shift: int = 0, # remove down_block_types: Tuple[str] = ( "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D", ), - mid_block_type: Optional[str] = "UNetMidBlock3DCrossAttn", up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), - only_cross_attention: Union[bool, Tuple[bool]] = False, # remove block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, # remove + norm_eps: float = 1e-5, cross_attention_dim: int = 1024, attention_head_dim: Union[int, Tuple[int]] = 64, - use_linear_projection: bool = True, # remove - class_embed_type: Optional[str] = None, # remove - num_class_embeds: Optional[int] = None, # remove - upcast_attention: bool = False, # remvoe - resnet_time_scale_shift: str = "default", # remove - time_embedding_type: str = "positional", # remove - timestep_post_act: Optional[str] = None, # remove - time_cond_proj_dim: Optional[int] = None, # remove - projection_class_embeddings_input_dim: Optional[int] = None, # remove ): super().__init__() @@ -161,11 +119,6 @@ def __init__( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." @@ -180,30 +133,14 @@ def __init__( ) # time - if time_embedding_type == "fourier": - time_embed_dim = block_out_channels[0] * 2 - if time_embed_dim % 2 != 0: - raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") - self.time_proj = GaussianFourierProjection( - time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos - ) - timestep_input_dim = time_embed_dim - elif time_embedding_type == "positional": - time_embed_dim = block_out_channels[0] * 4 - - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - else: - raise ValueError( - f"{time_embedding_type} does not exist. Pleaes make sure to use one of `fourier` or `positional`." - ) + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, - post_act_fn=timestep_post_act, - cond_proj_dim=time_cond_proj_dim, ) self.transformer_in = TransformerTempModel( @@ -217,9 +154,6 @@ def __init__( self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) - if isinstance(only_cross_attention, bool): - only_cross_attention = [only_cross_attention] * len(down_block_types) - if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) @@ -244,10 +178,6 @@ def __init__( attn_num_head_channels=attention_head_dim[i], downsample_padding=downsample_padding, dual_cross_attention=False, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, ) self.down_blocks.append(down_block) @@ -258,13 +188,10 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, ) # count how many layers upsample the images @@ -273,7 +200,6 @@ def __init__( # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_attention_head_dim = list(reversed(attention_head_dim)) - only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): @@ -304,10 +230,6 @@ def __init__( cross_attention_dim=cross_attention_dim, attn_num_head_channels=reversed_attention_head_dim[i], dual_cross_attention=False, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, ) self.up_blocks.append(up_block) prev_output_channel = output_channel diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 8a51adc002ff..64361e6cc7da 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -1,4 +1,5 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# Copyright 2023 The ModelScope Team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -53,6 +54,7 @@ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: + # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 # reshape to ncfhw mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) @@ -489,6 +491,8 @@ def __call__( The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds amounts to 2 seconds of video. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. From e47969d4157562356deb3ff46853bbc45e6ad25b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 15:02:02 +0000 Subject: [PATCH 32/43] finish --- src/diffusers/models/resnet.py | 13 +++++++++++-- src/diffusers/models/unet_3d_blocks.py | 2 +- .../pipeline_text_to_video_synth.py | 7 +++++-- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index cc4882aa0961..f79d6918bcdf 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -782,6 +782,11 @@ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): class TemporalConvLayer(nn.Module): + """ + Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: + https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 + """ + def __init__(self, in_dim, out_dim=None, dropout=0.0): super().__init__() out_dim = out_dim or in_dim @@ -816,7 +821,9 @@ def __init__(self, in_dim, out_dim=None, dropout=0.0): nn.init.zeros_(self.conv4[-1].bias) def forward(self, hidden_states, num_frames=1): - hidden_states = hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4) + hidden_states = ( + hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4) + ) identity = hidden_states hidden_states = self.conv1(hidden_states) @@ -826,5 +833,7 @@ def forward(self, hidden_states, num_frames=1): hidden_states = identity + hidden_states - hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape((hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape( + (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:] + ) return hidden_states diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index b59b4df5584b..40bb6702f7a0 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -15,7 +15,7 @@ import torch from torch import nn -from .resnet import Downsample2D, ResnetBlock2D, Upsample2D, TemporalConvLayer +from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D from .transformer_2d import Transformer2DModel from .transformer_temporal import TransformerTempModel diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 64361e6cc7da..8f285f9ed6d9 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -42,7 +42,9 @@ >>> from diffusers import TextToVideoMSPipeline >>> from diffusers.utils import export_to_video - >>> pipe = TextToVideoMSPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16") + >>> pipe = TextToVideoMSPipeline.from_pretrained( + ... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16" + ... ) >>> pipe.enable_model_cpu_offload() >>> prompt = "Spiderman is surfing" @@ -492,7 +494,8 @@ def __call__( width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_frames (`int`, *optional*, defaults to 16): - The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds amounts to 2 seconds of video. + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. From 436babeb7e79ef98444e5bf1aa98d835b4747657 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 15:03:48 +0000 Subject: [PATCH 33/43] up --- docs/source/en/api/pipelines/text_to_video.mdx | 18 +++++++++--------- src/diffusers/__init__.py | 2 +- src/diffusers/pipelines/__init__.py | 2 +- .../text_to_video_synthesis/__init__.py | 4 ++-- .../pipeline_text_to_video_synth.py | 16 ++++++++-------- .../dummy_torch_and_transformers_objects.py | 2 +- .../text_to_video/test_text_to_video.py | 18 +++++++++--------- 7 files changed, 31 insertions(+), 31 deletions(-) diff --git a/docs/source/en/api/pipelines/text_to_video.mdx b/docs/source/en/api/pipelines/text_to_video.mdx index a54fda575c21..33f236080c98 100644 --- a/docs/source/en/api/pipelines/text_to_video.mdx +++ b/docs/source/en/api/pipelines/text_to_video.mdx @@ -28,7 +28,7 @@ Resources: | Pipeline | Tasks | Demo |---|---|:---:| -| [TextToVideoMSPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py) | *Text-to-Video Generation* | [Spaces] (TODO) +| [TextToVideoSDPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py) | *Text-to-Video Generation* | [Spaces] (TODO) ## Usage example @@ -36,10 +36,10 @@ Let's start by generating a short video: ```python import torch -from diffusers import TextToVideoMSPipeline +from diffusers import TextToVideoSDPipeline from diffusers.utils import export_to_video -pipe = TextToVideoMSPipeline.from_pretrained("diffusers/ms-text-to-video-1.7b", torch_dtype=torch.float16) +pipe = TextToVideoSDPipeline.from_pretrained("diffusers/ms-text-to-video-1.7b", torch_dtype=torch.float16) pipe = pipe.to("cuda") prompt = "Spiderman is surfing" @@ -56,10 +56,10 @@ Let's generate a video of 8 seconds with CPU offloading and VAE slicing: ```python import torch -from diffusers import TextToVideoMSPipeline +from diffusers import TextToVideoSDPipeline from diffusers.utils import export_to_video -pipe = TextToVideoMSPipeline.from_pretrained("diffusers/ms-text-to-video-1.7b", torch_dtype=torch.float16) +pipe = TextToVideoSDPipeline.from_pretrained("diffusers/ms-text-to-video-1.7b", torch_dtype=torch.float16) pipe = pipe.to("cuda") # memory optimization @@ -78,10 +78,10 @@ We can also use a different scheduler easily: ```python import torch -from diffusers import TextToVideoMSPipeline, DPMSolverMultistepScheduler +from diffusers import TextToVideoSDPipeline, DPMSolverMultistepScheduler from diffusers.utils import export_to_video -pipe = TextToVideoMSPipeline.from_pretrained("diffusers/ms-text-to-video-1.7b", torch_dtype=torch.float16) +pipe = TextToVideoSDPipeline.from_pretrained("diffusers/ms-text-to-video-1.7b", torch_dtype=torch.float16) pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) pipe = pipe.to("cuda") @@ -117,7 +117,7 @@ Here are sample outputs: * [diffusers/ms-text-to-video-sd](https://huggingface.co/diffusers/ms-text-to-video-sd/) * [diffusers/ms-text-to-video-1.7b](https://huggingface.co/diffusers/ms-text-to-video-1.7b) -## TextToVideoMSPipeline -[[autodoc]] TextToVideoMSPipeline +## TextToVideoSDPipeline +[[autodoc]] TextToVideoSDPipeline - all - __call__ \ No newline at end of file diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index eebdbb425718..a1e736671be7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -131,7 +131,7 @@ StableDiffusionUpscalePipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, - TextToVideoMSPipeline, + TextToVideoSDPipeline, UnCLIPImageVariationPipeline, UnCLIPPipeline, VersatileDiffusionDualGuidedPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 1190ad7e1cf3..87d1a6997e59 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -65,7 +65,7 @@ StableUnCLIPPipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe - from .text_to_video_synthesis import TextToVideoMSPipeline + from .text_to_video_synthesis import TextToVideoSDPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .versatile_diffusion import ( VersatileDiffusionDualGuidedPipeline, diff --git a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py index c1a75a38927e..bd7103837c4f 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py @@ -7,7 +7,7 @@ @dataclass -class TextToVideoMSPipelineOutput(BaseOutput): +class TextToVideoSDPipelineOutput(BaseOutput): """ Output class for text to video pipelines. @@ -27,4 +27,4 @@ class TextToVideoMSPipelineOutput(BaseOutput): except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .pipeline_text_to_video_synth import TextToVideoMSPipeline # noqa: F401 + from .pipeline_text_to_video_synth import TextToVideoSDPipeline # noqa: F401 diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 8f285f9ed6d9..4d58424eae42 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -30,7 +30,7 @@ replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline -from . import TextToVideoMSPipelineOutput +from . import TextToVideoSDPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -39,10 +39,10 @@ Examples: ```py >>> import torch - >>> from diffusers import TextToVideoMSPipeline + >>> from diffusers import TextToVideoSDPipeline >>> from diffusers.utils import export_to_video - >>> pipe = TextToVideoMSPipeline.from_pretrained( + >>> pipe = TextToVideoSDPipeline.from_pretrained( ... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16" ... ) >>> pipe.enable_model_cpu_offload() @@ -73,7 +73,7 @@ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - return images -class TextToVideoMSPipeline(DiffusionPipeline): +class TextToVideoSDPipeline(DiffusionPipeline): r""" Pipeline for text-to-video generation. @@ -529,7 +529,7 @@ def __call__( weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoMSPipelineOutput`] instead of a + Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be @@ -545,8 +545,8 @@ def __call__( Examples: Returns: - [`~pipelines.stable_diffusion.TextToVideoMSPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.TextToVideoMSPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated frames. """ # 0. Default height and width to unet @@ -656,4 +656,4 @@ def __call__( if not return_dict: return (video,) - return TextToVideoMSPipelineOutput(frames=video) + return TextToVideoSDPipelineOutput(frames=video) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index c36b5efd9ab3..5a28ce8cb04e 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -347,7 +347,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class TextToVideoMSPipeline(metaclass=DummyObject): +class TextToVideoSDPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): diff --git a/tests/pipelines/text_to_video/test_text_to_video.py b/tests/pipelines/text_to_video/test_text_to_video.py index 62686567e5c5..d8ace61cc5dc 100644 --- a/tests/pipelines/text_to_video/test_text_to_video.py +++ b/tests/pipelines/text_to_video/test_text_to_video.py @@ -24,7 +24,7 @@ AutoencoderKL, DDIMScheduler, DPMSolverMultistepScheduler, - TextToVideoMSPipeline, + TextToVideoSDPipeline, UNet3DConditionModel, ) from diffusers.utils import load_numpy, skip_mps, slow, torch_device @@ -36,8 +36,8 @@ torch.backends.cuda.matmul.allow_tf32 = False -class TextToVideoMSPipelineFastTests(PipelineTesterMixin, unittest.TestCase): - pipeline_class = TextToVideoMSPipeline +class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = TextToVideoSDPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS # No `output_type`. @@ -126,7 +126,7 @@ def get_dummy_inputs(self, device, seed=0): def test_text_to_video_default_case(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() - sd_pipe = TextToVideoMSPipeline(**components) + sd_pipe = TextToVideoSDPipeline(**components) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) @@ -142,7 +142,7 @@ def test_text_to_video_default_case(self): def test_stable_diffusion_pix2pix_negative_prompt(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() - sd_pipe = TextToVideoMSPipeline(**components) + sd_pipe = TextToVideoSDPipeline(**components) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) @@ -162,7 +162,7 @@ def test_stable_diffusion_pix2pix_dpm_multistep(self): components["scheduler"] = DPMSolverMultistepScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) - sd_pipe = TextToVideoMSPipeline(**components) + sd_pipe = TextToVideoSDPipeline(**components) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) @@ -315,13 +315,13 @@ def test_save_load_optional_components(self): @slow -class TextToVideoMSPipelineSlowTests(unittest.TestCase): +class TextToVideoSDPipelineSlowTests(unittest.TestCase): def test_full_model(self): expected_video = load_numpy( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video.npy" ) - pipe = TextToVideoMSPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b") + pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b") pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) pipe = pipe.to("cuda") @@ -338,7 +338,7 @@ def test_two_step_model(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video_2step.npy" ) - pipe = TextToVideoMSPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b") + pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b") pipe = pipe.to("cuda") prompt = "Spiderman is surfing" From 03275f55d6993d551150cfca30a712d9d03dc356 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 16:19:22 +0100 Subject: [PATCH 34/43] Apply suggestions from code review --- src/diffusers/utils/__init__.py | 4 +--- src/diffusers/utils/testing_utils.py | 2 ++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index a43744db9bb8..f8d1c6f1280b 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -57,7 +57,6 @@ is_librosa_available, is_omegaconf_available, is_onnx_available, - is_opencv_available, is_safetensors_available, is_scipy_available, is_tensorboard_available, @@ -93,8 +92,7 @@ torch_device, ) -if is_opencv_available(): - from .testing_utils import export_to_video +from .testing_utils import export_to_video logger = get_logger(__name__) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 65ea371b85f6..467f36497233 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -263,6 +263,8 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str: if is_opencv_available(): import cv2 + else: + raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video")) if output_video_path is None: output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name From e700c08a5be2e8c65c9ae654c0ad0a39eb2fb9ea Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 15:19:44 +0000 Subject: [PATCH 35/43] up --- docs/source/en/api/pipelines/text_to_video.mdx | 10 +++++----- src/diffusers/models/attention.py | 4 ++++ src/diffusers/models/unet_3d_condition.py | 4 ++-- .../pipelines/text_to_video_synthesis/__init__.py | 11 ++++++----- .../pipeline_text_to_video_synth.py | 8 ++++++++ 5 files changed, 25 insertions(+), 12 deletions(-) diff --git a/docs/source/en/api/pipelines/text_to_video.mdx b/docs/source/en/api/pipelines/text_to_video.mdx index 33f236080c98..e0aeb895003f 100644 --- a/docs/source/en/api/pipelines/text_to_video.mdx +++ b/docs/source/en/api/pipelines/text_to_video.mdx @@ -39,7 +39,7 @@ import torch from diffusers import TextToVideoSDPipeline from diffusers.utils import export_to_video -pipe = TextToVideoSDPipeline.from_pretrained("diffusers/ms-text-to-video-1.7b", torch_dtype=torch.float16) +pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b-legacy", torch_dtype=torch.float16) pipe = pipe.to("cuda") prompt = "Spiderman is surfing" @@ -59,7 +59,7 @@ import torch from diffusers import TextToVideoSDPipeline from diffusers.utils import export_to_video -pipe = TextToVideoSDPipeline.from_pretrained("diffusers/ms-text-to-video-1.7b", torch_dtype=torch.float16) +pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b-legacy", torch_dtype=torch.float16) pipe = pipe.to("cuda") # memory optimization @@ -81,7 +81,7 @@ import torch from diffusers import TextToVideoSDPipeline, DPMSolverMultistepScheduler from diffusers.utils import export_to_video -pipe = TextToVideoSDPipeline.from_pretrained("diffusers/ms-text-to-video-1.7b", torch_dtype=torch.float16) +pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b-legacy", torch_dtype=torch.float16) pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) pipe = pipe.to("cuda") @@ -114,8 +114,8 @@ Here are sample outputs: ## Available checkpoints -* [diffusers/ms-text-to-video-sd](https://huggingface.co/diffusers/ms-text-to-video-sd/) -* [diffusers/ms-text-to-video-1.7b](https://huggingface.co/diffusers/ms-text-to-video-1.7b) +* [damo-vilab/text-to-video-ms-1.7b](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/) +* [damo-vilab/text-to-video-ms-1.7b-legacy](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b-legacy) ## TextToVideoSDPipeline [[autodoc]] TextToVideoSDPipeline diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index eaeb68decdbd..f271e00f8639 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -184,6 +184,10 @@ class BasicTransformerBlock(nn.Module): attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (: obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 40b91600edef..81981766da7e 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -79,7 +79,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. """ - _supports_gradient_checkpointing = True + _supports_gradient_checkpointing = False @register_to_config def __init__( @@ -333,7 +333,7 @@ def forward( ) -> Union[UNet3DConditionOutput, Tuple]: r""" Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + sample (`torch.FloatTensor`): (batch, num_frames, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): diff --git a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py index bd7103837c4f..c2437857a23a 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py @@ -2,6 +2,7 @@ from typing import List, Optional, Union import numpy as np +import torch from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available @@ -12,13 +13,13 @@ class TextToVideoSDPipelineOutput(BaseOutput): Output class for text to video pipelines. Args: - frames (`List[np.ndarray]`) - List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)`. - NumPy array present the denoised images of the diffusion pipeline. The length of the list denotes the video - length i.e., the number of frames. + frames (`List[np.ndarray]` or `torch.FloatTensor`) + List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as + a `torch` tensor. NumPy array present the denoised images of the diffusion pipeline. The length of the list + denotes the video length i.e., the number of frames. """ - frames: List[np.ndarray] + frames: Union[List[np.ndarray], torch.FloatTensor] try: diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 4d58424eae42..0bf11f3dd495 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -113,6 +113,7 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. @@ -122,6 +123,7 @@ def enable_vae_slicing(self): """ self.vae.enable_slicing() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to @@ -129,6 +131,7 @@ def disable_vae_slicing(self): """ self.vae.disable_slicing() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling def enable_vae_tiling(self): r""" Enable tiled VAE decoding. @@ -138,6 +141,7 @@ def enable_vae_tiling(self): """ self.vae.enable_tiling() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling def disable_vae_tiling(self): r""" Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to @@ -192,6 +196,7 @@ def enable_model_cpu_offload(self, gpu_id=0): self.final_offload_hook = hook @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling @@ -209,6 +214,7 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, @@ -370,6 +376,7 @@ def decode_latents(self, latents): video = video.float() return video + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -387,6 +394,7 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, From b7bebebde353456e7eb4a7d2f3568be64064b81b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 15:21:40 +0000 Subject: [PATCH 36/43] finish --- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/testing_utils.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index f8d1c6f1280b..d803b053be71 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -94,6 +94,7 @@ from .testing_utils import export_to_video + logger = get_logger(__name__) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 467f36497233..7a3b8029f828 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -18,6 +18,7 @@ from packaging import version from .import_utils import ( + BACKENDS_MAPPING, is_compel_available, is_flax_available, is_onnx_available, From fc832f935e594ff6343199da884f4ea979989c00 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 15:22:06 +0000 Subject: [PATCH 37/43] make copies --- .../text_to_video_synthesis/pipeline_text_to_video_synth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 0bf11f3dd495..d784de25f461 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -229,7 +229,7 @@ def _encode_prompt( Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device From 975b02d055f2365687e69ef0b70ce704bda34ce2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 16:16:32 +0000 Subject: [PATCH 38/43] fix pipeline tests --- docs/source/en/api/pipelines/overview.mdx | 1 + docs/source/en/index.mdx | 3 +- .../text_to_video/test_text_to_video.py | 163 +----------------- tests/test_pipelines_common.py | 23 ++- 4 files changed, 23 insertions(+), 167 deletions(-) diff --git a/docs/source/en/api/pipelines/overview.mdx b/docs/source/en/api/pipelines/overview.mdx index 6d0a9a1159b2..3bf29888ae54 100644 --- a/docs/source/en/api/pipelines/overview.mdx +++ b/docs/source/en/api/pipelines/overview.mdx @@ -77,6 +77,7 @@ available a colab notebook to directly try them out. | [stable_unclip](./stable_unclip) | **Stable unCLIP** | Text-to-Image Generation | | [stable_unclip](./stable_unclip) | **Stable unCLIP** | Image-to-Image Text-Guided Generation | | [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [text_to_video_sd](./api/pipelines/text_to_video) | [Modelscope's Text-to-video-synthesis Model in Open Domain](https://modelscope.cn/models/damo/text-to-video-synthesis/summary) | Text-to-Video Generation | | [unclip](./unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125) | Text-to-Image Generation | | [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation | | [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation | diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 59c4d595cc8b..2ccabb1b32ee 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -84,8 +84,9 @@ The library has three main components: | [stable_unclip](./stable_unclip) | Stable unCLIP | Text-to-Image Generation | | [stable_unclip](./stable_unclip) | Stable unCLIP | Image-to-Image Text-Guided Generation | | [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [text_to_video_sd](./api/pipelines/text_to_video) | [Modelscope's Text-to-video-synthesis Model in Open Domain](https://modelscope.cn/models/damo/text-to-video-synthesis/summary) | Text-to-Video Generation | | [unclip](./api/pipelines/unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125)(implementation by [kakaobrain](https://github.com/kakaobrain/karlo)) | Text-to-Image Generation | | [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation | | [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation | | [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation | -| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | \ No newline at end of file +| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | diff --git a/tests/pipelines/text_to_video/test_text_to_video.py b/tests/pipelines/text_to_video/test_text_to_video.py index d8ace61cc5dc..ba00433a977e 100644 --- a/tests/pipelines/text_to_video/test_text_to_video.py +++ b/tests/pipelines/text_to_video/test_text_to_video.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tempfile import unittest import numpy as np @@ -27,7 +26,7 @@ TextToVideoSDPipeline, UNet3DConditionModel, ) -from diffusers.utils import load_numpy, skip_mps, slow, torch_device +from diffusers.utils import load_numpy, skip_mps, slow from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ...test_pipelines_common import PipelineTesterMixin @@ -65,7 +64,6 @@ def get_dummy_components(self): up_block_types=("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), cross_attention_dim=32, attention_head_dim=4, - use_linear_projection=True, ) scheduler = DDIMScheduler( beta_start=0.00085, @@ -120,6 +118,7 @@ def get_dummy_inputs(self, device, seed=0): "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0, + "output_type": "pt", } return inputs @@ -131,6 +130,7 @@ def test_text_to_video_default_case(self): sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) + inputs["output_type"] = "np" frames = sd_pipe(**inputs).frames image_slice = frames[0][-3:, -3:, -1] @@ -139,41 +139,8 @@ def test_text_to_video_default_case(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - def test_stable_diffusion_pix2pix_negative_prompt(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - components = self.get_dummy_components() - sd_pipe = TextToVideoSDPipeline(**components) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - negative_prompt = "french fries" - frames = sd_pipe(**inputs, negative_prompt=negative_prompt).frames - image_slice = frames[0][-3:, -3:, -1] - - assert frames[0].shape == (64, 64, 3) - expected_slice = np.array([166, 181, 167, 119, 99, 124, 110, 94, 114]) - - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - - def test_stable_diffusion_pix2pix_dpm_multistep(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - components = self.get_dummy_components() - components["scheduler"] = DPMSolverMultistepScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" - ) - sd_pipe = TextToVideoSDPipeline(**components) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - frames = sd_pipe(**inputs).frames - image_slice = frames[0][-3:, -3:, -1] - - assert frames[0].shape == (64, 64, 3) - expected_slice = np.array([170, 190, 180, 140, 121, 136, 121, 97, 122]) - - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_attention_slicing_forward_pass(self): + self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False) # (todo): sayakpaul @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.") @@ -189,130 +156,10 @@ def test_inference_batch_single_identical(self): def test_num_images_per_prompt(self): pass - # Overriding since the output type for this pipeline differs from that of - # text-to-image pipelines. - @skip_mps - def test_attention_slicing_forward_pass(self): - self._test_attention_slicing_forward_pass() - - def _test_attention_slicing_forward_pass(self, expected_max_diff=4e-3): - if not self.test_attention_slicing: - return - - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - # Warmup pass when using mps (see #372) - if torch_device == "mps": - _ = pipe(**self.get_dummy_inputs(torch_device)) - - inputs = self.get_dummy_inputs(torch_device) - output_without_slicing = pipe(**inputs).frames[0] - - pipe.enable_attention_slicing(slice_size=1) - inputs = self.get_dummy_inputs(torch_device) - output_with_slicing = pipe(**inputs).frames[0] - - max_diff = np.abs((output_with_slicing / 255.0) - (output_without_slicing / 255.0)).max() - self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results") - - avg_diff = np.abs(output_without_slicing - output_without_slicing).mean() - self.assertLess(avg_diff, 10, f"Error image deviates {avg_diff} pixels on average") - - # Overriding since the output type for this pipeline differs from that of - # text-to-image pipelines. - @skip_mps - def test_dict_tuple_outputs_equivalent(self): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - # Warmup pass when using mps (see #372) - if torch_device == "mps": - _ = pipe(**self.get_dummy_inputs(torch_device)) - - output = pipe(**self.get_dummy_inputs(torch_device)).frames[0] - output_tuple = pipe(**self.get_dummy_inputs(torch_device), return_dict=False)[0][0] - - max_diff = np.abs(output / 255.0 - output_tuple / 255.0).max() - self.assertLess(max_diff, 1e-4) - @skip_mps def test_progress_bar(self): return super().test_progress_bar() - # Overriding since the output type for this pipeline differs from that of - # text-to-image pipelines. - @skip_mps - def test_save_load_local(self): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - # Warmup pass when using mps (see #372) - if torch_device == "mps": - _ = pipe(**self.get_dummy_inputs(torch_device)) - - inputs = self.get_dummy_inputs(torch_device) - output = pipe(**inputs).frames[0] - - with tempfile.TemporaryDirectory() as tmpdir: - pipe.save_pretrained(tmpdir) - pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) - pipe_loaded.to(torch_device) - pipe_loaded.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(torch_device) - output_loaded = pipe_loaded(**inputs).frames[0] - - max_diff = np.abs((output / 255.0) - (output_loaded / 255.0)).max() - self.assertLess(max_diff, 1e-4) - - # Overriding since the output type for this pipeline differs from that of - # text-to-image pipelines. - @skip_mps - def test_save_load_optional_components(self): - if not hasattr(self.pipeline_class, "_optional_components"): - return - - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - # Warmup pass when using mps (see #372) - if torch_device == "mps": - _ = pipe(**self.get_dummy_inputs(torch_device)) - - # set all optional components to None - for optional_component in pipe._optional_components: - setattr(pipe, optional_component, None) - - inputs = self.get_dummy_inputs(torch_device) - output = pipe(**inputs).frames[0] - - with tempfile.TemporaryDirectory() as tmpdir: - pipe.save_pretrained(tmpdir) - pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) - pipe_loaded.to(torch_device) - pipe_loaded.set_progress_bar_config(disable=None) - - for optional_component in pipe._optional_components: - self.assertTrue( - getattr(pipe_loaded, optional_component) is None, - f"`{optional_component}` did not stay set to None after loading.", - ) - - inputs = self.get_dummy_inputs(torch_device) - output_loaded = pipe_loaded(**inputs).frames[0] - - max_diff = np.abs((output / 255.0) - (output_loaded / 255.0)).max() - self.assertLess(max_diff, 1e-4) - @slow class TextToVideoSDPipelineSlowTests(unittest.TestCase): diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index 1ab6baeb81a3..ac2abd716e42 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -20,6 +20,13 @@ torch.backends.cuda.matmul.allow_tf32 = False +def to_np(tensor): + if isinstance(tensor, torch.Tensor): + tensor = tensor.detach().cpu().numpy() + + return tensor + + @require_torch class PipelineTesterMixin: """ @@ -130,7 +137,7 @@ def test_save_load_local(self): inputs = self.get_dummy_inputs(torch_device) output_loaded = pipe_loaded(**inputs)[0] - max_diff = np.abs(output - output_loaded).max() + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() self.assertLess(max_diff, 1e-4) def test_pipeline_call_signature(self): @@ -327,7 +334,7 @@ def test_dict_tuple_outputs_equivalent(self): output = pipe(**self.get_dummy_inputs(torch_device))[0] output_tuple = pipe(**self.get_dummy_inputs(torch_device), return_dict=False)[0] - max_diff = np.abs(output - output_tuple).max() + max_diff = np.abs(to_np(output) - to_np(output_tuple)).max() self.assertLess(max_diff, 1e-4) def test_components_function(self): @@ -351,7 +358,7 @@ def test_float16_inference(self): output = pipe(**self.get_dummy_inputs(torch_device))[0] output_fp16 = pipe_fp16(**self.get_dummy_inputs(torch_device))[0] - max_diff = np.abs(output - output_fp16).max() + max_diff = np.abs(to_np(output) - to_np(output_fp16)).max() self.assertLess(max_diff, 1e-2, "The outputs of the fp16 and fp32 pipelines are too different.") @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") @@ -383,7 +390,7 @@ def test_save_load_float16(self): inputs = self.get_dummy_inputs(torch_device) output_loaded = pipe_loaded(**inputs)[0] - max_diff = np.abs(output - output_loaded).max() + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() self.assertLess(max_diff, 1e-2, "The output of the fp16 pipeline changed after saving and loading.") def test_save_load_optional_components(self): @@ -421,7 +428,7 @@ def test_save_load_optional_components(self): inputs = self.get_dummy_inputs(torch_device) output_loaded = pipe_loaded(**inputs)[0] - max_diff = np.abs(output - output_loaded).max() + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() self.assertLess(max_diff, 1e-4) @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") @@ -442,7 +449,7 @@ def test_to_device(self): self.assertTrue(all(device == "cuda" for device in model_devices)) output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0] - self.assertTrue(np.isnan(output_cuda).sum() == 0) + self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) def test_to_dtype(self): components = self.get_dummy_components() @@ -482,7 +489,7 @@ def _test_attention_slicing_forward_pass( output_with_slicing = pipe(**inputs)[0] if test_max_difference: - max_diff = np.abs(output_with_slicing - output_without_slicing).max() + max_diff = np.abs(to_np(output_with_slicing) - to_np(output_without_slicing)).max() self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results") if test_mean_pixel_difference: @@ -508,7 +515,7 @@ def test_cpu_offload_forward_pass(self): inputs = self.get_dummy_inputs(torch_device) output_with_offload = pipe(**inputs)[0] - max_diff = np.abs(output_with_offload - output_without_offload).max() + max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max() self.assertLess(max_diff, 1e-4, "CPU offloading should not affect the inference results") @unittest.skipIf( From 5b6be9b3d50aa8daf0578018db95e83f2eb66dba Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 16:24:33 +0000 Subject: [PATCH 39/43] fix more tests --- .../source/en/api/pipelines/text_to_video.mdx | 29 ++-- tests/models/test_models_unet_3d_condition.py | 155 ------------------ 2 files changed, 14 insertions(+), 170 deletions(-) diff --git a/docs/source/en/api/pipelines/text_to_video.mdx b/docs/source/en/api/pipelines/text_to_video.mdx index e0aeb895003f..04fba5f74c42 100644 --- a/docs/source/en/api/pipelines/text_to_video.mdx +++ b/docs/source/en/api/pipelines/text_to_video.mdx @@ -28,7 +28,7 @@ Resources: | Pipeline | Tasks | Demo |---|---|:---:| -| [TextToVideoSDPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py) | *Text-to-Video Generation* | [Spaces] (TODO) +| [DiffusionPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py) | *Text-to-Video Generation* | [Spaces] (TODO) ## Usage example @@ -36,10 +36,10 @@ Let's start by generating a short video: ```python import torch -from diffusers import TextToVideoSDPipeline +from diffusers import DiffusionPipeline from diffusers.utils import export_to_video -pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b-legacy", torch_dtype=torch.float16) +pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16") pipe = pipe.to("cuda") prompt = "Spiderman is surfing" @@ -56,18 +56,17 @@ Let's generate a video of 8 seconds with CPU offloading and VAE slicing: ```python import torch -from diffusers import TextToVideoSDPipeline +from diffusers import DiffusionPipeline from diffusers.utils import export_to_video -pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b-legacy", torch_dtype=torch.float16) -pipe = pipe.to("cuda") +pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16") +pipe.enable_model_cpu_offload() # memory optimization -pipe.enable_model_cpu_offload() pipe.enable_vae_slicing() prompt = "Darth Vader surfing a wave" -video_frames = pipe(prompt, num_frames=64, num_inference_steps=25).frames +video_frames = pipe(prompt, num_frames=64).frames video_path = export_to_video(video_frames) video_path ``` @@ -78,15 +77,15 @@ We can also use a different scheduler easily: ```python import torch -from diffusers import TextToVideoSDPipeline, DPMSolverMultistepScheduler +from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler from diffusers.utils import export_to_video -pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b-legacy", torch_dtype=torch.float16) +pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16") pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) -pipe = pipe.to("cuda") +pipe.enable_model_cpu_offload() prompt = "Spiderman is surfing" -video_frames = pipe(prompt).frames +video_frames = pipe(prompt, num_inference_steps=25).frames video_path = export_to_video(video_frames) video_path ``` @@ -117,7 +116,7 @@ Here are sample outputs: * [damo-vilab/text-to-video-ms-1.7b](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/) * [damo-vilab/text-to-video-ms-1.7b-legacy](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b-legacy) -## TextToVideoSDPipeline -[[autodoc]] TextToVideoSDPipeline +## DiffusionPipeline +[[autodoc]] DiffusionPipeline - all - - __call__ \ No newline at end of file + - __call__ diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index 794c1c5a0cab..a92b8edd5378 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -23,7 +23,6 @@ from diffusers.utils import ( floats_tensor, logging, - torch_all_close, torch_device, ) from diffusers.utils.import_utils import is_xformers_available @@ -120,47 +119,6 @@ def test_xformers_enable_works(self): == "XFormersAttnProcessor" ), "xformers is not enabled" - @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS") - def test_gradient_checkpointing(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - - assert not model.is_gradient_checkpointing and model.training - - out = model(**inputs_dict).sample - # run the backwards pass on the model. For backwards pass, for simplicity purpose, - # we won't calculate the loss and rather backprop on out.sum() - model.zero_grad() - - labels = torch.randn_like(out) - loss = (out - labels).mean() - loss.backward() - - # re-instantiate the model now enabling gradient checkpointing - model_2 = self.model_class(**init_dict) - # clone model - model_2.load_state_dict(model.state_dict()) - model_2.to(torch_device) - model_2.enable_gradient_checkpointing() - - assert model_2.is_gradient_checkpointing and model_2.training - - out_2 = model_2(**inputs_dict).sample - # run the backwards pass on the model. For backwards pass, for simplicity purpose, - # we won't calculate the loss and rather backprop on out.sum() - model_2.zero_grad() - loss_2 = (out_2 - labels).mean() - loss_2.backward() - - # compare the output and parameters gradients - self.assertTrue((loss - loss_2).abs() < 1e-4) - named_params = dict(model.named_parameters()) - named_params_2 = dict(model_2.named_parameters()) - for name, param in named_params.items(): - self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-4)) - # Overriding because `block_out_channels` needs to be different for this model. def test_forward_with_norm_groups(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -209,44 +167,6 @@ def test_determinism(self): max_diff = np.amax(np.abs(out_1 - out_2)) self.assertLessEqual(max_diff, 1e-5) - def test_model_with_attention_head_dim_tuple(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16, 16, 16) - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.sample - - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_model_with_use_linear_projection(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["use_linear_projection"] = True - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.sample - - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - def test_model_attention_slicing(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -271,81 +191,6 @@ def test_model_attention_slicing(self): output = model(**inputs_dict) assert output is not None - def test_model_slicable_head_dim(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16, 16, 16) - - model = self.model_class(**init_dict) - - def check_slicable_dim_attr(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - assert isinstance(module.sliceable_head_dim, int) - - for child in module.children(): - check_slicable_dim_attr(child) - - # retrieve number of attention layers - for module in model.children(): - check_slicable_dim_attr(module) - - def test_special_attn_proc(self): - class AttnEasyProc(torch.nn.Module): - def __init__(self, num): - super().__init__() - self.weight = torch.nn.Parameter(torch.tensor(num)) - self.is_run = False - self.number = 0 - self.counter = 0 - - def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - query = attn.to_q(hidden_states) - - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - hidden_states += self.weight - - self.is_run = True - self.counter += 1 - self.number = number - - return hidden_states - - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16, 16, 16) - - model = self.model_class(**init_dict) - model.to(torch_device) - - processor = AttnEasyProc(5.0) - - model.set_attn_processor(processor) - model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample - - assert processor.counter == 12 - assert processor.is_run - assert processor.number == 123 - # (`attn_processors`) needs to be implemented in this model for this test. # def test_lora_processors(self): From 9d7cd2de6381736731f0c56fee9236b75b5a5cb5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 17:44:14 +0100 Subject: [PATCH 40/43] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- docs/source/en/api/pipelines/text_to_video.mdx | 14 +++++++------- scripts/convert_ms_text_to_video_to_diffusers.py | 1 - src/diffusers/models/resnet.py | 4 ++-- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/docs/source/en/api/pipelines/text_to_video.mdx b/docs/source/en/api/pipelines/text_to_video.mdx index 04fba5f74c42..f1fe794e1537 100644 --- a/docs/source/en/api/pipelines/text_to_video.mdx +++ b/docs/source/en/api/pipelines/text_to_video.mdx @@ -32,7 +32,7 @@ Resources: ## Usage example -Let's start by generating a short video: +Let's start by generating a short video with the default length of 16 frames (2s at 8 fps): ```python import torch @@ -48,11 +48,11 @@ video_path = export_to_video(video_frames) video_path ``` -Diffusers supports different optimization techniques to for improving the latency +Diffusers supports different optimization techniques to improve the latency and memory footprint of a pipeline. Since videos are often more memory-heavy than images, -for this pipeline, we can enable CPU offloading and VAE slicing to keep the memory-footprint at bay. +we can enable CPU offloading and VAE slicing to keep the memory footprint at bay. -Let's generate a video of 8 seconds with CPU offloading and VAE slicing: +Let's generate a video of 8 seconds (64 frames) on the same GPU using CPU offloading and VAE slicing: ```python import torch @@ -71,9 +71,9 @@ video_path = export_to_video(video_frames) video_path ``` -Together with PyTorch 2.0, "fp16" as the precision and the above techniques, it just takes 7 GBs of GPU memory. +It just takes **7 GBs of GPU memory** to generate the 64 video frames using PyTorch 2.0, "fp16" precision and the techniques mentioned above. -We can also use a different scheduler easily: +We can also use a different scheduler easily, using the same method we'd use for Stable Diffusion: ```python import torch @@ -90,7 +90,7 @@ video_path = export_to_video(video_frames) video_path ``` -Here are sample outputs: +Here are some sample outputs: diff --git a/scripts/convert_ms_text_to_video_to_diffusers.py b/scripts/convert_ms_text_to_video_to_diffusers.py index 91e699932558..b99fede143b9 100644 --- a/scripts/convert_ms_text_to_video_to_diffusers.py +++ b/scripts/convert_ms_text_to_video_to_diffusers.py @@ -407,7 +407,6 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False parser.add_argument( "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." ) - # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml parser.add_argument( "--original_config_file", default=None, diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index f79d6918bcdf..98f8f19c896a 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -1,5 +1,5 @@ -# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. -# Copyright 2023 The ModelScope Team. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 522f3aee76b333b5fde2606c30bf588049f9e102 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 17:55:19 +0100 Subject: [PATCH 41/43] apply suggestions --- docs/source/en/api/models.mdx | 6 +- .../convert_ms_text_to_video_to_diffusers.py | 97 ------------------- src/diffusers/models/autoencoder_kl.py | 2 + src/diffusers/models/transformer_temporal.py | 25 ++--- src/diffusers/models/unet_3d_blocks.py | 8 +- src/diffusers/models/unet_3d_condition.py | 4 +- .../pipeline_text_to_video_synth.py | 29 +++--- 7 files changed, 33 insertions(+), 138 deletions(-) diff --git a/docs/source/en/api/models.mdx b/docs/source/en/api/models.mdx index a0fb3d61a65a..572f8873ba12 100644 --- a/docs/source/en/api/models.mdx +++ b/docs/source/en/api/models.mdx @@ -64,11 +64,11 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## Transformer2DModelOutput [[autodoc]] models.transformer_2d.Transformer2DModelOutput -## TransformerTempModel -[[autodoc]] models.transformer_temporal.TransformerTempModel +## TransformerTemporalModel +[[autodoc]] models.transformer_temporal.TransformerTemporalModel ## Transformer2DModelOutput -[[autodoc]] models.transformer_temporal.TransformerTempModelOutput +[[autodoc]] models.transformer_temporal.TransformerTemporalModelOutput ## PriorTransformer [[autodoc]] models.prior_transformer.PriorTransformer diff --git a/scripts/convert_ms_text_to_video_to_diffusers.py b/scripts/convert_ms_text_to_video_to_diffusers.py index 91e699932558..3102c7eede9b 100644 --- a/scripts/convert_ms_text_to_video_to_diffusers.py +++ b/scripts/convert_ms_text_to_video_to_diffusers.py @@ -407,104 +407,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False parser.add_argument( "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." ) - # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml - parser.add_argument( - "--original_config_file", - default=None, - type=str, - help="The YAML config file corresponding to the original architecture.", - ) - parser.add_argument( - "--num_in_channels", - default=None, - type=int, - help="The number of input channels. If `None` number of input channels will be automatically inferred.", - ) - parser.add_argument( - "--scheduler_type", - default="pndm", - type=str, - help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']", - ) - parser.add_argument( - "--pipeline_type", - default=None, - type=str, - help=( - "The pipeline type. One of 'FrozenOpenCLIPEmbedder', 'FrozenCLIPEmbedder', 'PaintByExample'" - ". If `None` pipeline will be automatically inferred." - ), - ) - parser.add_argument( - "--image_size", - default=None, - type=int, - help=( - "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2" - " Base. Use 768 for Stable Diffusion v2." - ), - ) - parser.add_argument( - "--prediction_type", - default=None, - type=str, - help=( - "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable" - " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2." - ), - ) - parser.add_argument( - "--extract_ema", - action="store_true", - help=( - "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" - " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" - " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." - ), - ) - parser.add_argument( - "--upcast_attention", - action="store_true", - help=( - "Whether the attention computation should always be upcasted. This is necessary when running stable" - " diffusion 2.1." - ), - ) - parser.add_argument( - "--from_safetensors", - action="store_true", - help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", - ) - parser.add_argument( - "--to_safetensors", - action="store_true", - help="Whether to store pipeline in safetensors format or not.", - ) parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") - parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") - parser.add_argument( - "--stable_unclip", - type=str, - default=None, - required=False, - help="Set if this is a stable unCLIP model. One of 'txt2img' or 'img2img'.", - ) - parser.add_argument( - "--stable_unclip_prior", - type=str, - default=None, - required=False, - help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. If `--stable_unclip` is set to `txt2img`, the karlo prior (https://huggingface.co/kakaobrain/karlo-v1-alpha/tree/main/prior) is selected by default.", - ) - parser.add_argument( - "--clip_stats_path", - type=str, - help="Path to the clip stats file. Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.", - required=False, - ) - parser.add_argument( - "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint." - ) args = parser.parse_args() unet_checkpoint = torch.load(args.checkpoint_path, map_location="cpu") diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 9c0161065e4c..8f65c2357cac 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -207,6 +207,7 @@ def blend_h(self, a, b, blend_extent): def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: r"""Encode a batch of images using a tiled encoder. + Args: When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is: @@ -253,6 +254,7 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Autoen def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: r"""Decode a batch of images using a tiled decoder. + Args: When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is: diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index 93e6ea2527f4..ece88b8db2d5 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -24,7 +24,7 @@ @dataclass -class TransformerTempModelOutput(BaseOutput): +class TransformerTemporalModelOutput(BaseOutput): """ Args: sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`) @@ -34,20 +34,9 @@ class TransformerTempModelOutput(BaseOutput): sample: torch.FloatTensor -class TransformerTempModel(ModelMixin, ConfigMixin): +class TransformerTemporalModel(ModelMixin, ConfigMixin): """ - Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual - embeddings) inputs. - - When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard - transformer action. Finally, reshape to image. - - When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional - embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict - classes of unnoised image. - - Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised - image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. + Transformer model for video-like data. Parameters: num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. @@ -141,9 +130,9 @@ def forward( Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: - [`~models.transformer_2d.TransformerTempModelOutput`] or `tuple`: - [`~models.transformer_2d.TransformerTempModelOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`: + [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. """ # 1. Input batch_frames, channel, height, width = hidden_states.shape @@ -184,4 +173,4 @@ def forward( if not return_dict: return (output,) - return TransformerTempModelOutput(sample=output) + return TransformerTemporalModelOutput(sample=output) diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 40bb6702f7a0..9f8ee2a22aab 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -17,7 +17,7 @@ from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D from .transformer_2d import Transformer2DModel -from .transformer_temporal import TransformerTempModel +from .transformer_temporal import TransformerTemporalModel def get_down_block( @@ -196,7 +196,7 @@ def __init__( ) ) temp_attentions.append( - TransformerTempModel( + TransformerTemporalModel( in_channels // attn_num_head_channels, attn_num_head_channels, in_channels=in_channels, @@ -327,7 +327,7 @@ def __init__( ) ) temp_attentions.append( - TransformerTempModel( + TransformerTemporalModel( out_channels // attn_num_head_channels, attn_num_head_channels, in_channels=out_channels, @@ -539,7 +539,7 @@ def __init__( ) ) temp_attentions.append( - TransformerTempModel( + TransformerTemporalModel( out_channels // attn_num_head_channels, attn_num_head_channels, in_channels=out_channels, diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 81981766da7e..de762acabebf 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -23,7 +23,7 @@ from ..utils import BaseOutput, logging from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin -from .transformer_temporal import TransformerTempModel +from .transformer_temporal import TransformerTemporalModel from .unet_3d_blocks import ( CrossAttnDownBlock3D, CrossAttnUpBlock3D, @@ -143,7 +143,7 @@ def __init__( act_fn=act_fn, ) - self.transformer_in = TransformerTempModel( + self.transformer_in = TransformerTemporalModel( num_attention_heads=8, attention_head_dim=attention_head_dim, in_channels=block_out_channels[0], diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index d784de25f461..453809ef6df7 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -1,5 +1,4 @@ -# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. -# Copyright 2023 The ModelScope Team. +# Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -478,7 +477,6 @@ def __call__( num_inference_steps: int = 50, guidance_scale: float = 9.0, negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, @@ -495,30 +493,28 @@ def __call__( Args: prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. + The height in pixels of the generated video. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. + The width in pixels of the generated video. num_frames (`int`, *optional*, defaults to 16): The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds amounts to 2 seconds of video. num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, + usually at the expense of lower video quality. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass + The prompt or prompts not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. @@ -526,9 +522,10 @@ def __call__( One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. + tensor will ge generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -536,6 +533,8 @@ def __call__( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generate video. Choose between `torch.FloatTensor` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] instead of a plain tuple. @@ -561,6 +560,8 @@ def __call__( height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor + num_images_per_prompt = 1 + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds From d4a11a35eb8291d989d07ca56d1724287a50402d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 18:18:00 +0100 Subject: [PATCH 42/43] up --- tests/pipelines/text_to_video/test_text_to_video.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pipelines/text_to_video/test_text_to_video.py b/tests/pipelines/text_to_video/test_text_to_video.py index ba00433a977e..eb43a360653a 100644 --- a/tests/pipelines/text_to_video/test_text_to_video.py +++ b/tests/pipelines/text_to_video/test_text_to_video.py @@ -43,7 +43,6 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): required_optional_params = frozenset( [ "num_inference_steps", - "num_images_per_prompt", "generator", "latents", "return_dict", From 04ac57445401cb2ca7462a5fbdb6a8f7a7c24ab8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Mar 2023 18:25:21 +0100 Subject: [PATCH 43/43] revert --- src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index f26e466ead34..ef4598433f82 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -200,8 +200,7 @@ def assign_to_checkpoint( # Global renaming happens here new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.temp_attentions.0") - new_path = new_path.replace("middle_block.3", "mid_block.resnets.1") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") if additional_replacements is not None: for replacement in additional_replacements: