@@ -292,7 +292,7 @@ def check_inputs(
         mask=None,
         reference_images=None,
     ):
-        base = self.vae_scale_factor_spatial * self.transformer.config.patch_size
+        base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
         if height % base != 0 or width % base != 0:
             raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")

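For reference, the new indexing reflects that the Wan transformer stores `patch_size` as a (temporal, height, width) tuple rather than a scalar, so the spatial divisibility base is the VAE spatial downscale times the spatial patch size. A minimal standalone sketch, assuming the common values `vae_scale_factor_spatial = 8` and `patch_size = (1, 2, 2)` (both are assumptions here; the real values come from the loaded model configs):

# Standalone sketch of the divisibility check, with assumed config values.
vae_scale_factor_spatial = 8      # assumed VAE spatial downscale factor
patch_size = (1, 2, 2)            # assumed (temporal, height, width) patch sizes

base = vae_scale_factor_spatial * patch_size[1]   # 8 * 2 = 16

for height, width in [(480, 832), (481, 832)]:
    if height % base != 0 or width % base != 0:
        print(f"rejected: {height}x{width} is not divisible by {base}")
    else:
        print(f"accepted: {height}x{width}")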
@@ -368,55 +368,95 @@ def preprocess_conditions(
         device: Optional[torch.device] = None,
     ):
         if video is not None:
-            video = self.video_processor.preprocess_video(video, None, None)  # Use the height/width of video
-            image_size = tuple(video.shape[-2:])
+            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+            video_height, video_width = self.video_processor.get_default_height_width(video[0])
+
+            if video_height * video_width > height * width:
+                scale = min(width / video_width, height / video_height)
+                video_height, video_width = int(video_height * scale), int(video_width * scale)
+
+            if video_height % base != 0 or video_width % base != 0:
+                logger.warning(
+                    f"Video height and width should be divisible by {base}, but got {video_height} and {video_width}."
+                )
+                video_height = (video_height // base) * base
+                video_width = (video_width // base) * base
+
+            assert video_height * video_width <= height * width
+
+            video = self.video_processor.preprocess_video(video, video_height, video_width)
+            image_size = (video_height, video_width)  # Use the height/width of video (with possible rescaling)
         else:
-            video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=dtype, device=device)
+            video = torch.zeros(batch_size, 3, num_frames, height, width, dtype=dtype, device=device)
             image_size = (height, width)  # Use the height/width provided by user

         if mask is not None:
-            mask = self.video_processor.preprocess_video(mask, height, width)
+            mask = self.video_processor.preprocess_video(mask, image_size[0], image_size[1])
         else:
-            mask = torch.ones_like(video, dtype=dtype, device=device)
+            mask = torch.ones_like(video)

         video = video.to(dtype=dtype, device=device)
         mask = mask.to(dtype=dtype, device=device)

-        reference_images_preprocessed = []
-        if reference_images is not None:
-            if not isinstance(reference_images, list):
-                reference_images = [reference_images]
-            for i, image in enumerate(reference_images):
-                image = self.video_processor.preprocess(image, None, None)  # Use the height/width of image
+        # Make a list of list of images where the outer list corresponds to video batch size and the inner list
+        # corresponds to list of conditioning images per video
+        if reference_images is None or isinstance(reference_images, PIL.Image.Image):
+            reference_images = [[reference_images] for _ in range(video.shape[0])]
+        elif isinstance(reference_images, (list, tuple)) and isinstance(next(iter(reference_images)), PIL.Image.Image):
+            reference_images = [reference_images]
+        elif (
+            isinstance(reference_images, (list, tuple))
+            and isinstance(next(iter(reference_images)), list)
+            and isinstance(next(iter(reference_images[0])), PIL.Image.Image)
+        ):
+            reference_images = reference_images
+        else:
+            raise ValueError(
+                "`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or "
+                f"`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}"
+            )
+
+        if video.shape[0] != len(reference_images):
+            raise ValueError(
+                f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
+            )

+        reference_images_preprocessed = []
+        for i, reference_images_batch in enumerate(reference_images):
+            preprocessed_images = []
+            for j, image in enumerate(reference_images_batch):
+                if image is None:
+                    continue
+                image = self.video_processor.preprocess(image, None, None)
                 img_height, img_width = image.shape[-2:]
                 scale = min(image_size[0] / img_height, image_size[1] / img_width)
                 new_height, new_width = int(img_height * scale), int(img_width * scale)
                 resized_image = torch.nn.functional.interpolate(
-                    image.unsqueeze(1), size=(new_height, new_width), mode="bilinear", align_corners=False
-                ).squeeze(1)
-
+                    image, size=(new_height, new_width), mode="bilinear", align_corners=False
+                ).squeeze(0)  # [C, H, W]
                 top = (image_size[0] - new_height) // 2
                 left = (image_size[1] - new_width) // 2
-                canvas = torch.ones(batch_size, 1, 3, *image_size, device=device, dtype=dtype)
-                canvas[:, :, :, top : top + new_height, left : left + new_width] = resized_image
-                reference_images_preprocessed.append(canvas)
+                canvas = torch.ones(3, *image_size, device=device, dtype=dtype)
+                canvas[:, top : top + new_height, left : left + new_width] = resized_image
+                preprocessed_images.append(canvas)
+            reference_images_preprocessed.append(preprocessed_images)

         return video, mask, reference_images_preprocessed

     def prepare_video_latents(
         self,
         video: torch.Tensor,
         mask: torch.Tensor,
-        reference_images: Optional[List[torch.Tensor]] = None,
+        reference_images: Optional[List[List[torch.Tensor]]] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
     ) -> torch.Tensor:
         if isinstance(generator, list):
             # TODO: support this
             raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.")

         if reference_images is None:
-            # For each batch of video, we set no reference image (as one or more can be passed by user)
+            # For each batch of video, we set no reference image
+            # (as one or more can be passed by user)
             reference_images = [[None] for _ in range(video.shape[0])]
         else:
             if video.shape[0] != len(reference_images):
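The reference-image branch in `preprocess_conditions` above letterboxes each conditioning image: it is resized to fit inside the working `image_size` without changing its aspect ratio, then pasted centered onto a canvas filled with ones (the padding value used above). A self-contained sketch of that logic (the `letterbox` helper name is hypothetical, not part of the pipeline):

import torch

def letterbox(image: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor:
    # image: [1, C, H, W] float tensor
    img_h, img_w = image.shape[-2:]
    scale = min(target_h / img_h, target_w / img_w)
    new_h, new_w = int(img_h * scale), int(img_w * scale)
    resized = torch.nn.functional.interpolate(
        image, size=(new_h, new_w), mode="bilinear", align_corners=False
    ).squeeze(0)  # [C, new_h, new_w]
    top = (target_h - new_h) // 2
    left = (target_w - new_w) // 2
    # Pad with ones and paste the resized image in the center
    canvas = torch.ones(image.shape[1], target_h, target_w, dtype=image.dtype)
    canvas[:, top : top + new_h, left : left + new_w] = resized
    return canvas  # [C, target_h, target_w]

print(letterbox(torch.rand(1, 3, 300, 500), 480, 832).shape)  # torch.Size([3, 480, 832])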
@@ -437,22 +477,24 @@ def prepare_video_latents(
             latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0)
         else:
             mask = mask.to(dtype=vae_dtype)
-            mask = [torch.where(m > 0.5, 1.0, 0.0) for m in mask]
-            inactive = [v * (1 - m) for v, m in zip(video, mask)]
-            reactive = [v * m for v, m in zip(video, mask)]
+            mask = torch.where(mask > 0.5, 1.0, 0.0)
+            inactive = video * (1 - mask)
+            reactive = video * mask
             inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
             reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax")
-            latents = [torch.cat([i, r], dim=0) for i, r in zip(inactive, reactive)]
+            latents = torch.cat([inactive, reactive], dim=1)

         latent_list = []
-        for latent, ref_images in zip(latents, reference_images):
-            if ref_images is not None:
-                ref_images = ref_images.to(dtype=vae_dtype)
-                ref_latents = retrieve_latents(self.vae.encode(ref_images), generator, sample_mode="argmax")
-                ref_latents = [torch.cat([r, torch.zeros_like(r)], dim=0) for r in ref_latents]
-                latent = torch.cat([*ref_latents, latent], dim=1)
+        for latent, reference_images_batch in zip(latents, reference_images):
+            for reference_image in reference_images_batch:
+                assert reference_image.ndim == 3
+                reference_image = reference_image.to(dtype=vae_dtype)
+                reference_image = reference_image[None, :, None, :, :]  # [1, C, 1, H, W]
+                reference_latent = retrieve_latents(self.vae.encode(reference_image), generator, sample_mode="argmax")
+                reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=1)
+                latent = torch.cat([reference_latent.squeeze(0), latent], dim=1)  # Concat across frame dimension
             latent_list.append(latent)
-        return latent_list
+        return torch.stack(latent_list)

     def prepare_masks(
         self,
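After this hunk, `prepare_video_latents` returns a single batched tensor: the masked-out (inactive) and masked-in (reactive) video latents are concatenated along channels, and each reference image contributes one zero-padded latent frame prepended along the frame axis. A shape-only sketch with random tensors standing in for VAE outputs (all sizes below are illustrative assumptions):

import torch

# Illustrative latent sizes only; the real ones come from the VAE and the inputs.
B, C, T, H, W = 1, 16, 13, 60, 104
inactive = torch.randn(B, C, T, H, W)              # stands in for encode(video * (1 - mask))
reactive = torch.randn(B, C, T, H, W)              # stands in for encode(video * mask)
latents = torch.cat([inactive, reactive], dim=1)   # [B, 2C, T, H, W]

# One reference image -> one latent frame, zero-padded to 2C channels and
# prepended along the frame dimension.
reference_latent = torch.randn(C, 1, H, W)
reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=0)  # [2C, 1, H, W]

per_sample = torch.cat([reference_latent, latents[0]], dim=1)   # [2C, 1 + T, H, W]
print(torch.stack([per_sample]).shape)                          # torch.Size([1, 32, 14, 60, 104])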
@@ -479,25 +521,28 @@ def prepare_masks(
479521 "Generating with more than one video is not yet supported. This may be supported in the future."
480522 )
481523
524+ transformer_patch_size = self .transformer .config .patch_size [1 ]
525+
482526 mask_list = []
483- transformer_patch_size = self .transformer .config .patch_size
484- for mask_ , ref_images in zip (mask , reference_images ):
485- num_frames , num_channels , height , width = mask_ .shape
527+ for mask_ , reference_images_batch in zip (mask , reference_images ):
528+ num_channels , num_frames , height , width = mask_ .shape
486529 new_num_frames = (num_frames + self .vae_scale_factor_temporal - 1 ) // self .vae_scale_factor_temporal
487530 new_height = height // (self .vae_scale_factor_spatial * transformer_patch_size ) * transformer_patch_size
488531 new_width = width // (self .vae_scale_factor_spatial * transformer_patch_size ) * transformer_patch_size
489- mask_ = mask_ [:, 0 , :, :]
490- mask_ = mask_ .view (num_frames , height , self .vae_scale_factor_spatial , width , self .vae_scale_factor_spatial )
491- mask_ = mask_ .permute (2 , 4 , 0 , 1 , 3 ).flatten (2 , 4 ).flatten (0 , 1 )
532+ mask_ = mask_ [0 , :, :, :]
533+ mask_ = mask_ .view (
534+ num_frames , new_height , self .vae_scale_factor_spatial , new_width , self .vae_scale_factor_spatial
535+ )
536+ mask_ = mask_ .permute (2 , 4 , 0 , 1 , 3 ).flatten (0 , 1 ) # [8x8, num_frames, new_height, new_width]
492537 mask_ = torch .nn .functional .interpolate (
493538 mask_ .unsqueeze (0 ), size = (new_num_frames , new_height , new_width ), mode = "nearest-exact"
494539 ).squeeze (0 )
495- if ref_images is not None :
496- num_ref_images = ref_images . size ( 0 )
497- mask_padding = torch .zeros_like (mask [: num_ref_images , :, :, :])
540+ num_ref_images = len ( reference_images_batch )
541+ if num_ref_images > 0 :
542+ mask_padding = torch .zeros_like (mask_ [: , :num_ref_images , :, :])
498543 mask_ = torch .cat ([mask_ , mask_padding ], dim = 1 )
499544 mask_list .append (mask_ )
500- return mask_list
545+ return torch . stack ( mask_list )
501546
502547 def prepare_latents (
503548 self ,
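The reworked mask preparation keeps a single mask channel, folds each spatial block of `vae_scale_factor_spatial` x `vae_scale_factor_spatial` pixels into separate channels, and then resizes to the latent frame count. A small standalone sketch of that reshape (the concrete sizes and scale factors below are illustrative assumptions):

import torch

s = 8                                                # assumed vae_scale_factor_spatial
num_frames, height, width = 17, 480, 832
new_num_frames = (num_frames + 4 - 1) // 4           # assumed vae_scale_factor_temporal = 4
new_height = height // (s * 2) * 2                   # assumed transformer_patch_size = 2
new_width = width // (s * 2) * 2

mask_ = torch.rand(num_frames, height, width)        # single channel of the pixel-space mask
mask_ = mask_.view(num_frames, new_height, s, new_width, s)
mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(0, 1)   # [s*s, num_frames, new_height, new_width]
mask_ = torch.nn.functional.interpolate(
    mask_.unsqueeze(0), size=(new_num_frames, new_height, new_width), mode="nearest-exact"
).squeeze(0)
print(mask_.shape)  # torch.Size([64, 5, 60, 104])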
@@ -746,12 +791,9 @@ def __call__(
         )

         conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator)
-        conditioning_latents = [c.to(transformer_dtype) for c in conditioning_latents]
-
         mask = self.prepare_masks(mask, reference_images, generator)
-        mask = [m.to(transformer_dtype) for m in mask]
-
-        conditioning_latents = [torch.cat([c, m], dim=1) for c, m in zip(conditioning_latents, mask)]
+        conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
+        conditioning_latents = conditioning_latents.to(transformer_dtype)

         num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(
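With both helpers now returning batched tensors, the per-sample list comprehensions in `__call__` collapse into one channel-wise concatenation and a single dtype cast. A shape-only sketch (sizes are illustrative assumptions, and `torch.bfloat16` stands in for `transformer_dtype`):

import torch

conditioning_latents = torch.randn(1, 32, 6, 60, 104)   # stands in for prepare_video_latents output
mask = torch.rand(1, 64, 6, 60, 104)                     # stands in for prepare_masks output
conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
conditioning_latents = conditioning_latents.to(torch.bfloat16)
print(conditioning_latents.shape, conditioning_latents.dtype)  # torch.Size([1, 96, 6, 60, 104]) torch.bfloat16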