
Commit ea301df

refactor
1 parent 834bfc6 commit ea301df

2 files changed: +105 -58 lines changed

src/diffusers/models/transformers/transformer_wan_vace.py

Lines changed: 17 additions & 12 deletions
@@ -175,7 +175,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
     """
 
     _supports_gradient_checkpointing = True
-    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+    _skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"]
     _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
@@ -273,9 +273,6 @@ def forward(
         return_dict: bool = True,
         attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
-        if control_hidden_states is None:
-            raise ValueError("Control hidden states must be provided for VACE models.")
-
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -299,16 +296,24 @@ def forward(
 
         if control_hidden_states_scale is None:
             control_hidden_states_scale = control_hidden_states.new_ones(len(self.config.vace_layers))
+        control_hidden_states_scale = torch.unbind(control_hidden_states_scale)
+        if len(control_hidden_states_scale) != len(self.config.vace_layers):
+            raise ValueError(
+                f"Length of `control_hidden_states_scale` {len(control_hidden_states_scale)} should be "
+                f"equal to {len(self.config.vace_layers)}."
+            )
 
         # 1. Rotary position embedding
         rotary_emb = self.rope(hidden_states)
 
         # 2. Patch embedding
         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
+        print("hidden_states", hidden_states.shape)
 
         control_hidden_states = self.vace_patch_embedding(control_hidden_states)
         control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2)
+        print("control_hidden_states", control_hidden_states.shape)
         control_hidden_states_padding = control_hidden_states.new_zeros(
             batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2)
         )
@@ -329,36 +334,36 @@ def forward(
             # Prepare VACE hints
             control_hidden_states_list = []
             vace_hidden_states = hidden_states
-            for block in self.vace_blocks:
+            for i, block in enumerate(self.vace_blocks):
                 vace_hidden_states, control_hidden_states = self._gradient_checkpointing_func(
                     block, vace_hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
                 )
-                control_hidden_states_list.append(control_hidden_states)
+                control_hidden_states_list.append((control_hidden_states, control_hidden_states_scale[i]))
             control_hidden_states_list = control_hidden_states_list[::-1]
 
             for i, block in enumerate(self.blocks):
                 hidden_states = self._gradient_checkpointing_func(
                     block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                 )
                 if i in self.config.vace_layers:
-                    control_hint = control_hidden_states_list.pop()
-                    hidden_states = hidden_states + control_hint * control_hidden_states_scale[i]
+                    control_hint, scale = control_hidden_states_list.pop()
+                    hidden_states = hidden_states + control_hint * scale
         else:
             # Prepare VACE hints
             control_hidden_states_list = []
             vace_hidden_states = hidden_states
-            for block in self.vace_blocks:
+            for i, block in enumerate(self.vace_blocks):
                 vace_hidden_states, control_hidden_states = block(
                     vace_hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
                 )
-                control_hidden_states_list.append(control_hidden_states)
+                control_hidden_states_list.append((control_hidden_states, control_hidden_states_scale[i]))
             control_hidden_states_list = control_hidden_states_list[::-1]
 
             for i, block in enumerate(self.blocks):
                 hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
                 if i in self.config.vace_layers:
-                    control_hint = control_hidden_states_list.pop()
-                    hidden_states = hidden_states + control_hint * control_hidden_states_scale[i]
+                    control_hint, scale = control_hidden_states_list.pop()
+                    hidden_states = hidden_states + control_hint * scale
 
         # 6. Output norm, projection & unpatchify
         shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
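For context (editor's sketch, not part of the commit): the new scale handling unbinds `control_hidden_states_scale` into one scalar per VACE layer, validates its length against `config.vace_layers`, and stores each hint together with its own scale, so the main block loop no longer indexes the scale list by block index. A minimal standalone illustration with made-up shapes and a hypothetical 3-layer / 6-block configuration:

import torch

# Hypothetical config: 3 VACE layers feeding blocks 0, 2 and 4 of a 6-block stack.
vace_layers = [0, 2, 4]
num_blocks = 6

# Default scale: one entry per VACE layer (mirrors `new_ones(len(vace_layers))`).
control_hidden_states_scale = torch.ones(len(vace_layers))

scales = torch.unbind(control_hidden_states_scale)  # tuple of 0-dim tensors, one per VACE layer
if len(scales) != len(vace_layers):
    raise ValueError(f"Length of `control_hidden_states_scale` {len(scales)} should be equal to {len(vace_layers)}.")

# Each hint is collected as a (hint, scale) pair...
hidden_states = torch.randn(1, 8, 16)  # [batch, tokens, dim], made-up sizes
hints = [(torch.randn_like(hidden_states), scales[i]) for i in range(len(vace_layers))]
hints = hints[::-1]  # reversed so `.pop()` yields hints in layer order

# ...and consumed with that same scale inside the main block loop.
for i in range(num_blocks):
    if i in vace_layers:
        control_hint, scale = hints.pop()
        hidden_states = hidden_states + control_hint * scale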

src/diffusers/pipelines/wan/pipeline_wan_vace.py

Lines changed: 88 additions & 46 deletions
@@ -292,7 +292,7 @@ def check_inputs(
         mask=None,
         reference_images=None,
     ):
-        base = self.vae_scale_factor_spatial * self.transformer.config.patch_size
+        base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
         if height % base != 0 or width % base != 0:
             raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
 
@@ -368,55 +368,95 @@ def preprocess_conditions(
         device: Optional[torch.device] = None,
     ):
         if video is not None:
-            video = self.video_processor.preprocess_video(video, None, None)  # Use the height/width of video
-            image_size = tuple(video.shape[-2:])
+            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+            video_height, video_width = self.video_processor.get_default_height_width(video[0])
+
+            if video_height * video_width > height * width:
+                scale = min(width / video_width, height / video_height)
+                video_height, video_width = int(video_height * scale), int(video_width * scale)
+
+            if video_height % base != 0 or video_width % base != 0:
+                logger.warning(
+                    f"Video height and width should be divisible by {base}, but got {video_height} and {video_width}. "
+                )
+                video_height = (video_height // base) * base
+                video_width = (video_width // base) * base
+
+            assert video_height * video_width <= height * width
+
+            video = self.video_processor.preprocess_video(video, video_height, video_width)
+            image_size = (video_height, video_width)  # Use the height/width of video (with possible rescaling)
         else:
-            video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=dtype, device=device)
+            video = torch.zeros(batch_size, 3, num_frames, height, width, dtype=dtype, device=device)
             image_size = (height, width)  # Use the height/width provider by user
 
         if mask is not None:
-            mask = self.video_processor.preprocess_video(mask, height, width)
+            mask = self.video_processor.preprocess_video(mask, image_size[0], image_size[1])
         else:
-            mask = torch.ones_like(video, dtype=dtype, device=device)
+            mask = torch.ones_like(video)
 
         video = video.to(dtype=dtype, device=device)
         mask = mask.to(dtype=dtype, device=device)
 
-        reference_images_preprocessed = []
-        if reference_images is not None:
-            if not isinstance(reference_images, list):
-                reference_images = [reference_images]
-            for i, image in enumerate(reference_images):
-                image = self.video_processor.preprocess(image, None, None)  # Use the height/width of image
+        # Make a list of list of images where the outer list corresponds to video batch size and the inner list
+        # corresponds to list of conditioning images per video
+        if reference_images is None or isinstance(reference_images, PIL.Image.Image):
+            reference_images = [[reference_images] for _ in range(video.shape[0])]
+        elif isinstance(reference_images, (list, tuple)) and isinstance(next(iter(reference_images)), PIL.Image.Image):
+            reference_images = [reference_images]
+        elif (
+            isinstance(reference_images, (list, tuple))
+            and isinstance(next(iter(reference_images)), list)
+            and isinstance(next(iter(reference_images[0])), PIL.Image.Image)
+        ):
+            reference_images = reference_images
+        else:
+            raise ValueError(
+                "`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or "
+                "`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}"
+            )
+
+        if video.shape[0] != len(reference_images):
+            raise ValueError(
+                f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
+            )
 
+        reference_images_preprocessed = []
+        for i, reference_images_batch in enumerate(reference_images):
+            preprocessed_images = []
+            for j, image in enumerate(reference_images_batch):
+                if image is None:
+                    continue
+                image = self.video_processor.preprocess(image, None, None)
                 img_height, img_width = image.shape[-2:]
                 scale = min(image_size[0] / img_height, image_size[1] / img_width)
                 new_height, new_width = int(img_height * scale), int(img_width * scale)
                 resized_image = torch.nn.functional.interpolate(
-                    image.unsqueeze(1), size=(new_height, new_width), mode="bilinear", align_corners=False
-                ).squeeze(1)
-
+                    image, size=(new_height, new_width), mode="bilinear", align_corners=False
+                ).squeeze(0)  # [C, H, W]
                 top = (image_size[0] - new_height) // 2
                 left = (image_size[1] - new_width) // 2
-                canvas = torch.ones(batch_size, 1, 3, *image_size, device=device, dtype=dtype)
-                canvas[:, :, :, top : top + new_height, left : left + new_width] = resized_image
-                reference_images_preprocessed.append(canvas)
+                canvas = torch.ones(3, *image_size, device=device, dtype=dtype)
+                canvas[:, top : top + new_height, left : left + new_width] = resized_image
+                preprocessed_images.append(canvas)
+            reference_images_preprocessed.append(preprocessed_images)
 
         return video, mask, reference_images_preprocessed
 
     def prepare_video_latents(
         self,
         video: torch.Tensor,
         mask: torch.Tensor,
-        reference_images: Optional[List[torch.Tensor]] = None,
+        reference_images: Optional[List[List[torch.Tensor]]] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
     ) -> torch.Tensor:
         if isinstance(generator, list):
             # TODO: support this
             raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.")
 
         if reference_images is None:
-            # For each batch of video, we set no reference image (as one or more can be passed by user)
+            # For each batch of video, we set no reference image
+            # (as one or more can be passed by user)
             reference_images = [[None] for _ in range(video.shape[0])]
         else:
             if video.shape[0] != len(reference_images):
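Editor's note (not part of the diff): the reworked preprocessing accepts `reference_images` as a single image, a list per video, or a list of lists, and letterboxes every reference image onto a white canvas at the video's resolution. A standalone sketch of that resize-and-center step, with invented sizes:

import torch
import torch.nn.functional as F

# A stand-in for one preprocessed reference image, [1, C, H, W]; the target size comes from the video.
image = torch.rand(1, 3, 480, 832)
image_size = (720, 1280)  # (height, width) of the conditioning video, made up here

img_height, img_width = image.shape[-2:]
scale = min(image_size[0] / img_height, image_size[1] / img_width)
new_height, new_width = int(img_height * scale), int(img_width * scale)

# Resize while preserving aspect ratio, then drop the batch dim: [C, new_H, new_W].
resized = F.interpolate(image, size=(new_height, new_width), mode="bilinear", align_corners=False).squeeze(0)

# Paste the resized image centered onto a white canvas so it matches the video frames' size.
canvas = torch.ones(3, *image_size)
top = (image_size[0] - new_height) // 2
left = (image_size[1] - new_width) // 2
canvas[:, top : top + new_height, left : left + new_width] = resized
print(canvas.shape)  # torch.Size([3, 720, 1280])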
@@ -437,22 +477,24 @@ def prepare_video_latents(
             latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0)
         else:
             mask = mask.to(dtype=vae_dtype)
-            mask = [torch.where(m > 0.5, 1.0, 0.0) for m in mask]
-            inactive = [v * (1 - m) for v, m in zip(video, mask)]
-            reactive = [v * m for v, m in zip(video, mask)]
+            mask = torch.where(mask > 0.5, 1.0, 0.0)
+            inactive = video * (1 - mask)
+            reactive = video * mask
             inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
             reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax")
-            latents = [torch.cat([i, r], dim=0) for i, r in zip(inactive, reactive)]
+            latents = torch.cat([inactive, reactive], dim=1)
 
         latent_list = []
-        for latent, ref_images in zip(latents, reference_images):
-            if ref_images is not None:
-                ref_images = ref_images.to(dtype=vae_dtype)
-                ref_latents = retrieve_latents(self.vae.encode(ref_images), generator, sample_mode="argmax")
-                ref_latents = [torch.cat([r, torch.zeros_like(r)], dim=0) for r in ref_latents]
-                latent = torch.cat([*ref_latents, latent], dim=1)
+        for latent, reference_images_batch in zip(latents, reference_images):
+            for reference_image in reference_images_batch:
+                assert reference_image.ndim == 3
+                reference_image = reference_image.to(dtype=vae_dtype)
+                reference_image = reference_image[None, :, None, :, :]  # [1, C, 1, H, W]
+                reference_latent = retrieve_latents(self.vae.encode(reference_image), generator, sample_mode="argmax")
+                reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=1)
+                latent = torch.cat([reference_latent.squeeze(0), latent], dim=1)  # Concat across frame dimension
             latent_list.append(latent)
-        return latent_list
+        return torch.stack(latent_list)
 
     def prepare_masks(
         self,
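Editor's note (not part of the diff): `prepare_video_latents` now works on batched tensors instead of Python lists; the mask is binarized once, the video is split into an "inactive" part to preserve and a "reactive" part to regenerate, and the two encoded results are concatenated along the channel dimension. A standalone sketch with made-up shapes and a placeholder standing in for `retrieve_latents(self.vae.encode(...))`:

import torch

# Made-up [B, C, F, H, W] video and mask tensors standing in for the preprocessed inputs.
video = torch.rand(1, 3, 9, 64, 64)
mask = torch.rand(1, 3, 9, 64, 64)

mask = torch.where(mask > 0.5, 1.0, 0.0)  # hard-binarize in one batched op
inactive = video * (1 - mask)  # pixels to keep
reactive = video * mask        # pixels to regenerate


def fake_encode(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for the VAE: just subsample frames by 4 and space by 8.
    return x[:, :, ::4, ::8, ::8]


latents = torch.cat([fake_encode(inactive), fake_encode(reactive)], dim=1)  # channel-wise concat
print(latents.shape)  # torch.Size([1, 6, 3, 8, 8])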
@@ -479,25 +521,28 @@ def prepare_masks(
                 "Generating with more than one video is not yet supported. This may be supported in the future."
             )
 
+        transformer_patch_size = self.transformer.config.patch_size[1]
+
         mask_list = []
-        transformer_patch_size = self.transformer.config.patch_size
-        for mask_, ref_images in zip(mask, reference_images):
-            num_frames, num_channels, height, width = mask_.shape
+        for mask_, reference_images_batch in zip(mask, reference_images):
+            num_channels, num_frames, height, width = mask_.shape
             new_num_frames = (num_frames + self.vae_scale_factor_temporal - 1) // self.vae_scale_factor_temporal
             new_height = height // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size
             new_width = width // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size
-            mask_ = mask_[:, 0, :, :]
-            mask_ = mask_.view(num_frames, height, self.vae_scale_factor_spatial, width, self.vae_scale_factor_spatial)
-            mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(2, 4).flatten(0, 1)
+            mask_ = mask_[0, :, :, :]
+            mask_ = mask_.view(
+                num_frames, new_height, self.vae_scale_factor_spatial, new_width, self.vae_scale_factor_spatial
+            )
+            mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(0, 1)  # [8x8, num_frames, new_height, new_width]
             mask_ = torch.nn.functional.interpolate(
                 mask_.unsqueeze(0), size=(new_num_frames, new_height, new_width), mode="nearest-exact"
             ).squeeze(0)
-            if ref_images is not None:
-                num_ref_images = ref_images.size(0)
-                mask_padding = torch.zeros_like(mask[:num_ref_images, :, :, :])
+            num_ref_images = len(reference_images_batch)
+            if num_ref_images > 0:
+                mask_padding = torch.zeros_like(mask_[:, :num_ref_images, :, :])
                 mask_ = torch.cat([mask_, mask_padding], dim=1)
             mask_list.append(mask_)
-        return mask_list
+        return torch.stack(mask_list)
 
     def prepare_latents(
         self,
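Editor's note (not part of the diff): the mask handling above assumes [C, F, H, W] masks and reshapes one channel so that every offset inside a spatial block becomes its own latent-resolution sub-mask, then interpolates to the latent frame count. A standalone sketch, assuming the usual Wan factors (spatial VAE factor 8, temporal factor 4, spatial patch size 2) and invented sizes:

import torch

# Assumed scale factors; the real values come from the pipeline/VAE config.
vae_scale_factor_spatial = 8
vae_scale_factor_temporal = 4
transformer_patch_size = 2

num_frames, height, width = 17, 480, 832
mask_ = torch.rand(1, num_frames, height, width)[0]  # keep a single channel: [F, H, W]

new_num_frames = (num_frames + vae_scale_factor_temporal - 1) // vae_scale_factor_temporal
new_height = height // (vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size
new_width = width // (vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size

# Each of the 8*8 offsets within a spatial block becomes its own sub-mask at latent resolution.
mask_ = mask_.view(num_frames, new_height, vae_scale_factor_spatial, new_width, vae_scale_factor_spatial)
mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(0, 1)  # [8*8, num_frames, new_height, new_width]

# Resize along the frame axis to the latent frame count.
mask_ = torch.nn.functional.interpolate(
    mask_.unsqueeze(0), size=(new_num_frames, new_height, new_width), mode="nearest-exact"
).squeeze(0)
print(mask_.shape)  # torch.Size([64, 5, 60, 104])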
@@ -746,12 +791,9 @@ def __call__(
         )
 
         conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator)
-        conditioning_latents = [c.to(transformer_dtype) for c in conditioning_latents]
-
         mask = self.prepare_masks(mask, reference_images, generator)
-        mask = [m.to(transformer_dtype) for m in mask]
-
-        conditioning_latents = [torch.cat([c, m], dim=1) for c, m in zip(conditioning_latents, mask)]
+        conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
+        conditioning_latents = conditioning_latents.to(transformer_dtype)
 
         num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(
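Editor's note (not part of the diff): with both helpers now returning batched tensors, `__call__` can build the conditioning input with a single channel-wise concat and one dtype cast. A trivial standalone sketch with assumed shapes and dtype:

import torch

# Assumed shapes: [B, C, F_lat, H_lat, W_lat] conditioning latents and a matching mask tensor.
transformer_dtype = torch.bfloat16
conditioning_latents = torch.rand(1, 32, 6, 60, 104)
mask = torch.rand(1, 64, 6, 60, 104)

conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)  # concat along channels
conditioning_latents = conditioning_latents.to(transformer_dtype)      # single cast at the end
print(conditioning_latents.shape, conditioning_latents.dtype)  # torch.Size([1, 96, 6, 60, 104]) torch.bfloat16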
