From 78d0fa0dc823e65c0e4c42bc3b06adb53c407ca7 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sat, 3 Feb 2024 20:04:53 +0530
Subject: [PATCH 1/7] remove _to_tensor

---
 src/diffusers/models/unets/unet_i2vgen_xl.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index 6b78968cb505..921172d30054 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -492,7 +492,7 @@ def disable_freeu(self):
     def forward(
         self,
         sample: torch.FloatTensor,
-        timestep: Union[torch.Tensor, float, int],
+        timestep: torch.Tensor,
         fps: torch.Tensor,
         image_latents: torch.Tensor,
         image_embeddings: Optional[torch.Tensor] = None,
@@ -507,7 +507,7 @@ def forward(
         Args:
             sample (`torch.FloatTensor`): The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
-            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            timestep (`torch.Tensor`): The number of timesteps to denoise an input.
             fps (`torch.Tensor`): Frames per second for the video being generated. Used as a "micro-condition".
             image_latents (`torch.FloatTensor`): Image encodings from the VAE.
             image_embeddings (`torch.FloatTensor`): Projection embeddings of the conditioning image computed with a vision encoder.
@@ -543,8 +543,19 @@ def forward(
         forward_upsample_size = True
 
         # 1. time
-        timesteps = _to_tensor(timestep, sample.device)
-
+        timesteps = timestep
+        if not torch.is_tensor(inputs):
+            # TODO: this requires sync between CPU and GPU. So try to pass `inputs` as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(inputs, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            inputs = torch.tensor([inputs], dtype=dtype, device=sample.device)
+        elif len(inputs.shape) == 0:
+            inputs = inputs[None].to(sample.device)
+
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timesteps = timesteps.expand(sample.shape[0])
         t_emb = self.time_proj(timesteps)
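Note on [PATCH 1/7]: the block inlined above still refers to the old helper's parameter name `inputs`, which is undefined inside `forward`; [PATCH 7/7] below re-lands the same logic under the `timesteps` name. A minimal standalone sketch of the conversion being inlined (the `to_timestep_tensor` helper name is illustrative, not part of the diff):

```python
import torch


def to_timestep_tensor(timestep, device):
    # Sketch: accept an int, float, 0-d tensor, or 1-d tensor and return a 1-d tensor on `device`.
    if not torch.is_tensor(timestep):
        # float64 is not supported on MPS, so the original helper falls back to 32-bit dtypes there.
        is_mps = device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timestep = torch.tensor([timestep], dtype=dtype, device=device)
    elif timestep.ndim == 0:
        timestep = timestep[None].to(device)
    return timestep
```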
From f1cca4de18f2faa4d3b722d5478c153b7735b6db Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sun, 4 Feb 2024 17:40:54 +0530
Subject: [PATCH 2/7] remove _to_tensor definition

---
 src/diffusers/models/unets/unet_i2vgen_xl.py | 16 ----------------
 1 file changed, 16 deletions(-)

diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index 921172d30054..d45be91cf3d7 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -48,22 +48,6 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-def _to_tensor(inputs, device):
-    if not torch.is_tensor(inputs):
-        # TODO: this requires sync between CPU and GPU. So try to pass `inputs` as tensors if you can
-        # This would be a good case for the `match` statement (Python 3.10+)
-        is_mps = device.type == "mps"
-        if isinstance(inputs, float):
-            dtype = torch.float32 if is_mps else torch.float64
-        else:
-            dtype = torch.int32 if is_mps else torch.int64
-        inputs = torch.tensor([inputs], dtype=dtype, device=device)
-    elif len(inputs.shape) == 0:
-        inputs = inputs[None].to(device)
-
-    return inputs
-
-
 def _collapse_frames_into_batch(sample: torch.Tensor) -> torch.Tensor:
     batch_size, channels, num_frames, height, width = sample.shape
     sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
 
     return sample

From b20796680d8e9689e3aae271508b3518efac7436 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sun, 4 Feb 2024 18:27:33 +0530
Subject: [PATCH 3/7] remove _collapse_frames_into_batch

---
 src/diffusers/models/unets/unet_i2vgen_xl.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index d45be91cf3d7..02dc8cb2eb1d 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -48,13 +48,6 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-def _collapse_frames_into_batch(sample: torch.Tensor) -> torch.Tensor:
-    batch_size, channels, num_frames, height, width = sample.shape
-    sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
-
-    return sample
-
-
 class I2VGenXLTransformerTemporalEncoder(nn.Module):
     def __init__(
         self,
@@ -567,7 +560,8 @@ def forward(
         context_emb = sample.new_zeros(batch_size, 0, self.config.cross_attention_dim)
         context_emb = torch.cat([context_emb, encoder_hidden_states], dim=1)
 
-        image_latents_context_embs = _collapse_frames_into_batch(image_latents[:, :, :1, :])
+        image_latents_for_context_embds = image_latents[:, :, :1, :]
+        image_latents_context_embs = image_latents_for_context_embds.permute(0, 2, 1, 3, 4).reshape(image_latents_for_context_embds.shape[0] * image_latents_for_context_embds.shape[2], image_latents_for_context_embds.shape[1], image_latents_for_context_embds.shape[3], image_latents_for_context_embds.shape[4])
 
         image_latents_context_embs = self.image_latents_context_embedding(image_latents_context_embs)
         _batch_size, _channels, _height, _width = image_latents_context_embs.shape
@@ -581,7 +575,7 @@ def forward(
         context_emb = torch.cat([context_emb, image_emb], dim=1)
         context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)
 
-        image_latents = _collapse_frames_into_batch(image_latents)
+        image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(image_latents.shape[0] * image_latents.shape[2], image_latents.shape[1], image_latents.shape[3], image_latents.shape[4])
         image_latents = self.image_latents_proj_in(image_latents)
         image_latents = (
             image_latents[None, :]
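Note on [PATCH 3/7]: it inlines the frame-collapsing rearrangement that `_collapse_frames_into_batch` used to perform on the image latents, folding `(batch, channels, frames, height, width)` into a `(batch * frames, channels, height, width)` batch. A self-contained sketch of that equivalence (tensor names and shapes are illustrative only):

```python
import torch

batch, channels, frames, height, width = 2, 4, 8, 32, 32
image_latents = torch.randn(batch, channels, frames, height, width)

# Move the frame axis next to the batch axis, then fold it into the batch dimension,
# which is exactly what the removed helper did in a single call.
collapsed = image_latents.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)
print(collapsed.shape)  # torch.Size([16, 4, 32, 32])
```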
From a164561a74fde24fd282c3980cf6ca1b6f6ef716 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sun, 4 Feb 2024 18:31:29 +0530
Subject: [PATCH 4/7] remove lora for not bloating the code.

---
 .../pipelines/i2vgen_xl/pipeline_i2vgen_xl.py | 24 -------------------
 1 file changed, 24 deletions(-)

diff --git a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
index ec6b328abdda..79c71ea82e2e 100644
--- a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
+++ b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
@@ -22,9 +22,7 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import LoraLoaderMixin
 from ...models import AutoencoderKL
-from ...models.lora import adjust_lora_scale_text_encoder
 from ...models.unets.unet_i2vgen_xl import I2VGenXLUNet
 from ...schedulers import DDIMScheduler
 from ...utils import (
@@ -205,7 +203,6 @@ def encode_prompt(
         negative_prompt=None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        lora_scale: Optional[float] = None,
         clip_skip: Optional[int] = None,
     ):
         r"""
@@ -231,23 +228,10 @@ def encode_prompt(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
-            lora_scale (`float`, *optional*):
-                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
         """
-        # set lora scale so that monkey patched LoRA
-        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
-            self._lora_scale = lora_scale
-
-            # dynamically adjust the LoRA scale
-            if not USE_PEFT_BACKEND:
-                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            else:
-                scale_lora_layers(self.text_encoder, lora_scale)
-
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -378,10 +362,6 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
 
-        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
-            # Retrieve the original scale by scaling back the LoRA layers
-            unscale_lora_layers(self.text_encoder, lora_scale)
-
         return prompt_embeds, negative_prompt_embeds
 
     def _encode_image(self, image, device, num_videos_per_prompt):
@@ -704,9 +684,6 @@ def __call__(
         self._guidance_scale = guidance_scale
 
         # 3.1 Encode input text prompt
-        text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
-        )
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
@@ -714,7 +691,6 @@ def __call__(
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
             clip_skip=clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
From 3147f7a001f8c05c6305300eb65efea6099a5c29 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sun, 4 Feb 2024 18:43:51 +0530
Subject: [PATCH 5/7] remove sample_size.

---
 src/diffusers/models/unets/unet_i2vgen_xl.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index 02dc8cb2eb1d..69e26526fa4a 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -151,8 +151,6 @@ def __init__(
     ):
         super().__init__()
 
-        self.sample_size = sample_size
-
         # Check inputs
         if len(down_block_types) != len(up_block_types):
            raise ValueError(
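Note on [PATCH 5/7]: `sample_size` is an `__init__` argument, so, assuming the usual diffusers `@register_to_config` behavior for this model, the value is still recorded on the config after the instance attribute is dropped. A rough sketch of where it remains reachable (random weights, illustrative value):

```python
from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet

# Build the default architecture only to inspect the registered config;
# real usage would load pretrained weights instead.
unet = I2VGenXLUNet(sample_size=32)
print(unet.config.sample_size)  # -> 32, even though `self.sample_size` no longer exists
```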
From dcef3adf21f28168274fb9d2a1797e4fe06e537d Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sun, 4 Feb 2024 22:20:58 +0530
Subject: [PATCH 6/7] simplify code a bit more

---
 src/diffusers/models/unets/unet_i2vgen_xl.py  | 28 +++++++++----------
 .../pipelines/i2vgen_xl/pipeline_i2vgen_xl.py |  3 --
 2 files changed, 14 insertions(+), 17 deletions(-)

diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index 69e26526fa4a..74aff9594bdd 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -519,18 +519,8 @@ def forward(
 
         # 1. time
         timesteps = timestep
-        if not torch.is_tensor(inputs):
-            # TODO: this requires sync between CPU and GPU. So try to pass `inputs` as tensors if you can
-            # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
-            if isinstance(inputs, float):
-                dtype = torch.float32 if is_mps else torch.float64
-            else:
-                dtype = torch.int32 if is_mps else torch.int64
-            inputs = torch.tensor([inputs], dtype=dtype, device=sample.device)
-        elif len(inputs.shape) == 0:
-            inputs = inputs[None].to(sample.device)
-
+        timesteps = timesteps.to(sample.device)
+
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timesteps = timesteps.expand(sample.shape[0])
         t_emb = self.time_proj(timesteps)
@@ -549,7 +549,12 @@ def forward(
         context_emb = torch.cat([context_emb, encoder_hidden_states], dim=1)
 
         image_latents_for_context_embds = image_latents[:, :, :1, :]
-        image_latents_context_embs = image_latents_for_context_embds.permute(0, 2, 1, 3, 4).reshape(image_latents_for_context_embds.shape[0] * image_latents_for_context_embds.shape[2], image_latents_for_context_embds.shape[1], image_latents_for_context_embds.shape[3], image_latents_for_context_embds.shape[4])
+        image_latents_context_embs = image_latents_for_context_embds.permute(0, 2, 1, 3, 4).reshape(
+            image_latents_for_context_embds.shape[0] * image_latents_for_context_embds.shape[2],
+            image_latents_for_context_embds.shape[1],
+            image_latents_for_context_embds.shape[3],
+            image_latents_for_context_embds.shape[4],
+        )
 
         image_latents_context_embs = self.image_latents_context_embedding(image_latents_context_embs)
         _batch_size, _channels, _height, _width = image_latents_context_embs.shape
@@ -573,7 +568,12 @@ def forward(
         context_emb = torch.cat([context_emb, image_emb], dim=1)
         context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)
 
-        image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(image_latents.shape[0] * image_latents.shape[2], image_latents.shape[1], image_latents.shape[3], image_latents.shape[4])
+        image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
+            image_latents.shape[0] * image_latents.shape[2],
+            image_latents.shape[1],
+            image_latents.shape[3],
+            image_latents.shape[4],
+        )
         image_latents = self.image_latents_proj_in(image_latents)
         image_latents = (
             image_latents[None, :]
diff --git a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
index 79c71ea82e2e..a4549f027ce4 100644
--- a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
+++ b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
@@ -26,12 +26,9 @@
 from ...models.unets.unet_i2vgen_xl import I2VGenXLUNet
 from ...schedulers import DDIMScheduler
 from ...utils import (
-    USE_PEFT_BACKEND,
     BaseOutput,
     logging,
     replace_example_docstring,
-    scale_lora_layers,
-    unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline

From 3d26814f9ee7c085f2b19240af7e92856858fc37 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 5 Feb 2024 08:21:45 +0530
Subject: [PATCH 7/7] ensure timesteps are always in tensor.

---
 src/diffusers/models/unets/unet_i2vgen_xl.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index 74aff9594bdd..de4acb7e0d07 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -467,7 +467,7 @@ def disable_freeu(self):
     def forward(
         self,
         sample: torch.FloatTensor,
-        timestep: torch.Tensor,
+        timestep: Union[torch.Tensor, float, int],
         fps: torch.Tensor,
         image_latents: torch.Tensor,
         image_embeddings: Optional[torch.Tensor] = None,
@@ -482,7 +482,7 @@ def forward(
         Args:
             sample (`torch.FloatTensor`): The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
-            timestep (`torch.Tensor`): The number of timesteps to denoise an input.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
             fps (`torch.Tensor`): Frames per second for the video being generated. Used as a "micro-condition".
             image_latents (`torch.FloatTensor`): Image encodings from the VAE.
             image_embeddings (`torch.FloatTensor`): Projection embeddings of the conditioning image computed with a vision encoder.
@@ -519,7 +519,17 @@ def forward(
 
         # 1. time
         timesteps = timestep
-        timesteps = timesteps.to(sample.device)
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timesteps, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
 
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timesteps = timesteps.expand(sample.shape[0])
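Taken together, [PATCH 6/7] and [PATCH 7/7] differ only in what `forward` may assume about `timestep`: the simplified `.to(sample.device)` call works only when a tensor is passed, while the restored branch also accepts plain Python numbers. A small illustration of that difference outside the model (shapes and values are illustrative only):

```python
import torch

sample = torch.randn(2, 4, 16, 32, 32)  # illustrative (batch, channels, num_frames, height, width)
timestep = 981  # a caller may pass a plain Python int rather than a tensor

try:
    timestep.to(sample.device)  # the PATCH 6/7 simplification assumes a tensor here
except AttributeError as err:
    print(err)  # 'int' object has no attribute 'to'

# The PATCH 7/7 path converts first (see the sketch after PATCH 1/7), then broadcasts over the batch:
timesteps = torch.tensor([timestep], dtype=torch.int64, device=sample.device)
print(timesteps.expand(sample.shape[0]))  # tensor([981, 981])
```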