Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 25 additions & 28 deletions src/diffusers/models/unets/unet_i2vgen_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,29 +48,6 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def _to_tensor(inputs, device):
if not torch.is_tensor(inputs):
# TODO: this requires sync between CPU and GPU. So try to pass `inputs` as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = device.type == "mps"
if isinstance(inputs, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
inputs = torch.tensor([inputs], dtype=dtype, device=device)
elif len(inputs.shape) == 0:
inputs = inputs[None].to(device)

return inputs


def _collapse_frames_into_batch(sample: torch.Tensor) -> torch.Tensor:
batch_size, channels, num_frames, height, width = sample.shape
sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

return sample


class I2VGenXLTransformerTemporalEncoder(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -174,8 +151,6 @@ def __init__(
):
super().__init__()

self.sample_size = sample_size

# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
Expand Down Expand Up @@ -543,7 +518,18 @@ def forward(
forward_upsample_size = True

# 1. time
timesteps = _to_tensor(timestep, sample.device)
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timesteps, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
Expand Down Expand Up @@ -572,7 +558,13 @@ def forward(
context_emb = sample.new_zeros(batch_size, 0, self.config.cross_attention_dim)
context_emb = torch.cat([context_emb, encoder_hidden_states], dim=1)

image_latents_context_embs = _collapse_frames_into_batch(image_latents[:, :, :1, :])
image_latents_for_context_embds = image_latents[:, :, :1, :]
image_latents_context_embs = image_latents_for_context_embds.permute(0, 2, 1, 3, 4).reshape(
image_latents_for_context_embds.shape[0] * image_latents_for_context_embds.shape[2],
image_latents_for_context_embds.shape[1],
image_latents_for_context_embds.shape[3],
image_latents_for_context_embds.shape[4],
)
image_latents_context_embs = self.image_latents_context_embedding(image_latents_context_embs)

_batch_size, _channels, _height, _width = image_latents_context_embs.shape
Expand All @@ -586,7 +578,12 @@ def forward(
context_emb = torch.cat([context_emb, image_emb], dim=1)
context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)

image_latents = _collapse_frames_into_batch(image_latents)
image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
image_latents.shape[0] * image_latents.shape[2],
image_latents.shape[1],
image_latents.shape[3],
image_latents.shape[4],
)
image_latents = self.image_latents_proj_in(image_latents)
image_latents = (
image_latents[None, :]
Expand Down
27 changes: 0 additions & 27 deletions src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,13 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin
from ...models import AutoencoderKL
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_i2vgen_xl import I2VGenXLUNet
from ...schedulers import DDIMScheduler
from ...utils import (
USE_PEFT_BACKEND,
BaseOutput,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
Expand Down Expand Up @@ -207,7 +202,6 @@ def encode_prompt(
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Expand All @@ -233,23 +227,10 @@ def encode_prompt(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
# set lora scale so that monkey patched LoRA
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice cleanup :)

# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)

if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
Expand Down Expand Up @@ -380,10 +361,6 @@ def encode_prompt(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

def _encode_image(self, image, device, num_videos_per_prompt):
Expand Down Expand Up @@ -706,17 +683,13 @@ def __call__(
self._guidance_scale = guidance_scale

# 3.1 Encode input text prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
Expand Down