diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index 6b78968cb505..de4acb7e0d07 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -48,29 +48,6 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def _to_tensor(inputs, device):
-    if not torch.is_tensor(inputs):
-        # TODO: this requires sync between CPU and GPU. So try to pass `inputs` as tensors if you can
-        # This would be a good case for the `match` statement (Python 3.10+)
-        is_mps = device.type == "mps"
-        if isinstance(inputs, float):
-            dtype = torch.float32 if is_mps else torch.float64
-        else:
-            dtype = torch.int32 if is_mps else torch.int64
-        inputs = torch.tensor([inputs], dtype=dtype, device=device)
-    elif len(inputs.shape) == 0:
-        inputs = inputs[None].to(device)
-
-    return inputs
-
-
-def _collapse_frames_into_batch(sample: torch.Tensor) -> torch.Tensor:
-    batch_size, channels, num_frames, height, width = sample.shape
-    sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
-
-    return sample
-
-
 class I2VGenXLTransformerTemporalEncoder(nn.Module):
     def __init__(
         self,
@@ -174,8 +151,6 @@ def __init__(
     ):
         super().__init__()

-        self.sample_size = sample_size
-
         # Check inputs
         if len(down_block_types) != len(up_block_types):
             raise ValueError(
@@ -543,7 +518,18 @@ def forward(
             forward_upsample_size = True

         # 1. time
-        timesteps = _to_tensor(timestep, sample.device)
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timesteps, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)

         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timesteps = timesteps.expand(sample.shape[0])
@@ -572,7 +558,13 @@ def forward(
         context_emb = sample.new_zeros(batch_size, 0, self.config.cross_attention_dim)
         context_emb = torch.cat([context_emb, encoder_hidden_states], dim=1)

-        image_latents_context_embs = _collapse_frames_into_batch(image_latents[:, :, :1, :])
+        image_latents_for_context_embds = image_latents[:, :, :1, :]
+        image_latents_context_embs = image_latents_for_context_embds.permute(0, 2, 1, 3, 4).reshape(
+            image_latents_for_context_embds.shape[0] * image_latents_for_context_embds.shape[2],
+            image_latents_for_context_embds.shape[1],
+            image_latents_for_context_embds.shape[3],
+            image_latents_for_context_embds.shape[4],
+        )
         image_latents_context_embs = self.image_latents_context_embedding(image_latents_context_embs)

         _batch_size, _channels, _height, _width = image_latents_context_embs.shape
@@ -586,7 +578,12 @@ def forward(
         context_emb = torch.cat([context_emb, image_emb], dim=1)
         context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)

-        image_latents = _collapse_frames_into_batch(image_latents)
+        image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
+            image_latents.shape[0] * image_latents.shape[2],
+            image_latents.shape[1],
+            image_latents.shape[3],
+            image_latents.shape[4],
+        )
         image_latents = self.image_latents_proj_in(image_latents)
         image_latents = (
             image_latents[None, :]
diff --git a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
index 57a1449d8634..5988957cb10f 100644
--- a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
+++ b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
@@ -22,18 +22,13 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import LoraLoaderMixin
 from ...models import AutoencoderKL
-from ...models.lora import adjust_lora_scale_text_encoder
 from ...models.unets.unet_i2vgen_xl import I2VGenXLUNet
 from ...schedulers import DDIMScheduler
 from ...utils import (
-    USE_PEFT_BACKEND,
     BaseOutput,
     logging,
     replace_example_docstring,
-    scale_lora_layers,
-    unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
@@ -207,7 +202,6 @@ def encode_prompt(
         negative_prompt=None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        lora_scale: Optional[float] = None,
         clip_skip: Optional[int] = None,
     ):
         r"""
@@ -233,23 +227,10 @@
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
-            lora_scale (`float`, *optional*):
-                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means
                 that the output of the pre-final layer will be used for computing the prompt embeddings.
         """
-        # set lora scale so that monkey patched LoRA
-        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
-            self._lora_scale = lora_scale
-
-            # dynamically adjust the LoRA scale
-            if not USE_PEFT_BACKEND:
-                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            else:
-                scale_lora_layers(self.text_encoder, lora_scale)
-
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -380,10 +361,6 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
-            # Retrieve the original scale by scaling back the LoRA layers
-            unscale_lora_layers(self.text_encoder, lora_scale)
-
         return prompt_embeds, negative_prompt_embeds

     def _encode_image(self, image, device, num_videos_per_prompt):
@@ -706,9 +683,6 @@ def __call__(
         self._guidance_scale = guidance_scale

         # 3.1 Encode input text prompt
-        text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
-        )
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
@@ -716,7 +690,6 @@
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
             clip_skip=clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
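Note (not part of the patch): the permute/reshape inlined into `forward` above folds the frame axis of a (batch, channels, frames, height, width) latent into the batch axis, which is what the removed `_collapse_frames_into_batch` helper did. A minimal sketch with a dummy tensor, assuming only `torch` and made-up toy dimensions:

    import torch

    # Dummy 5-D video latent: (batch, channels, frames, height, width).
    batch, channels, frames, height, width = 2, 4, 16, 32, 32
    image_latents = torch.randn(batch, channels, frames, height, width)

    # Move frames next to the batch axis, then merge the two:
    # (B, C, F, H, W) -> (B, F, C, H, W) -> (B * F, C, H, W).
    collapsed = image_latents.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)
    assert collapsed.shape == (batch * frames, channels, height, width)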