@@ -1127,6 +1127,8 @@ def define_schema(cls):
11271127                io .Image .Input ("face_video" , optional = True ),
11281128                io .Image .Input ("pose_video" , optional = True ),
11291129                io .Int .Input ("continue_motion_max_frames" , default = 5 , min = 1 , max = nodes .MAX_RESOLUTION , step = 4 ),
1130+                 io .Image .Input ("background_video" , optional = True ),
1131+                 io .Mask .Input ("character_mask" , optional = True ),
11301132                io .Image .Input ("continue_motion" , optional = True ),
11311133                io .Int .Input ("video_frame_offset" , default = 0 , min = 0 , max = nodes .MAX_RESOLUTION , step = 1 , tooltip = "The amount of frames to seek in all the input videos. Used for generating longer videos by chunk. Connect to the video_frame_offset output of the previous node for extending a video." ),
11321134            ],
@@ -1142,7 +1144,7 @@ def define_schema(cls):
11421144        )
11431145
11441146    @classmethod  
1145-     def  execute (cls , positive , negative , vae , width , height , length , batch_size , continue_motion_max_frames , video_frame_offset , reference_image = None , clip_vision_output = None , face_video = None , pose_video = None , continue_motion = None ) ->  io .NodeOutput :
1147+     def  execute (cls , positive , negative , vae , width , height , length , batch_size , continue_motion_max_frames , video_frame_offset , reference_image = None , clip_vision_output = None , face_video = None , pose_video = None , continue_motion = None ,  background_video = None ,  character_mask = None ) ->  io .NodeOutput :
11461148        trim_to_pose_video  =  False 
11471149        latent_length  =  ((length  -  1 ) //  4 ) +  1 
11481150        latent_width  =  width  //  8 
@@ -1154,7 +1156,7 @@ def execute(cls, positive, negative, vae, width, height, length, batch_size, con
11541156
11551157        image  =  comfy .utils .common_upscale (reference_image [:length ].movedim (- 1 , 1 ), width , height , "area" , "center" ).movedim (1 , - 1 )
11561158        concat_latent_image  =  vae .encode (image [:, :, :, :3 ])
1157-         mask  =  torch .zeros ((1 , 1 , concat_latent_image .shape [2 ], concat_latent_image .shape [- 2 ], concat_latent_image .shape [- 1 ]), device = concat_latent_image .device , dtype = concat_latent_image .dtype )
1159+         mask  =  torch .zeros ((1 , 4 , concat_latent_image .shape [- 3 ], concat_latent_image .shape [- 2 ], concat_latent_image .shape [- 1 ]), device = concat_latent_image .device , dtype = concat_latent_image .dtype )
11581160        trim_latent  +=  concat_latent_image .shape [2 ]
11591161        ref_motion_latent_length  =  0 
11601162
@@ -1206,11 +1208,37 @@ def execute(cls, positive, negative, vae, width, height, length, batch_size, con
12061208            positive  =  node_helpers .conditioning_set_values (positive , {"face_video_pixels" : face_video })
12071209            negative  =  node_helpers .conditioning_set_values (negative , {"face_video_pixels" : face_video  *  0.0  -  1.0 })
12081210
1209-         concat_latent_image  =  torch .cat ((concat_latent_image , vae .encode (image [:, :, :, :3 ])), dim = 2 )
1210-         mask_refmotion  =  torch .ones ((1 , 1 , latent_length , concat_latent_image .shape [- 2 ], concat_latent_image .shape [- 1 ]), device = mask .device , dtype = mask .dtype )
1211+         ref_images_num  =  max (0 , ref_motion_latent_length  *  4  -  3 )
1212+         if  background_video  is  not None :
1213+             if  background_video .shape [0 ] >  video_frame_offset :
1214+                 background_video  =  background_video [video_frame_offset :]
1215+                 background_video  =  comfy .utils .common_upscale (background_video [:length ].movedim (- 1 , 1 ), width , height , "area" , "center" ).movedim (1 , - 1 )
1216+                 if  background_video .shape [0 ] >  ref_images_num :
1217+                     image [ref_images_num :background_video .shape [0 ] -  ref_images_num ] =  background_video [ref_images_num :]
1218+ 
1219+         mask_refmotion  =  torch .ones ((1 , 1 , latent_length  *  4 , concat_latent_image .shape [- 2 ], concat_latent_image .shape [- 1 ]), device = mask .device , dtype = mask .dtype )
12111220        if  continue_motion  is  not None :
1212-             mask_refmotion [:, :, :ref_motion_latent_length ] =  0.0 
1221+             mask_refmotion [:, :, :ref_motion_latent_length  *  4 ] =  0.0 
1222+ 
1223+         if  character_mask  is  not None :
1224+             if  character_mask .shape [0 ] >  video_frame_offset  or  character_mask .shape [0 ] ==  1 :
1225+                 if  character_mask .shape [0 ] ==  1 :
1226+                     character_mask  =  character_mask .repeat ((length ,) +  (1 ,) *  (character_mask .ndim  -  1 ))
1227+                 else :
1228+                     character_mask  =  character_mask [video_frame_offset :]
1229+                 if  character_mask .ndim  ==  3 :
1230+                     character_mask  =  character_mask .unsqueeze (1 )
1231+                     character_mask  =  character_mask .movedim (0 , 1 )
1232+                 if  character_mask .ndim  ==  4 :
1233+                     character_mask  =  character_mask .unsqueeze (1 )
1234+                 character_mask  =  comfy .utils .common_upscale (character_mask [:, :, :length ], concat_latent_image .shape [- 1 ], concat_latent_image .shape [- 2 ], "nearest-exact" , "center" )
1235+                 if  character_mask .shape [2 ] >  ref_images_num :
1236+                     mask_refmotion [:, :, ref_images_num :character_mask .shape [2 ] +  ref_images_num ] =  character_mask [:, :, ref_images_num :]
1237+ 
1238+         concat_latent_image  =  torch .cat ((concat_latent_image , vae .encode (image [:, :, :, :3 ])), dim = 2 )
1239+ 
12131240
1241+         mask_refmotion  =  mask_refmotion .view (1 , mask_refmotion .shape [2 ] //  4 , 4 , mask_refmotion .shape [3 ], mask_refmotion .shape [4 ]).transpose (1 , 2 )
12141242        mask  =  torch .cat ((mask , mask_refmotion ), dim = 2 )
12151243        positive  =  node_helpers .conditioning_set_values (positive , {"concat_latent_image" : concat_latent_image , "concat_mask" : mask })
12161244        negative  =  node_helpers .conditioning_set_values (negative , {"concat_latent_image" : concat_latent_image , "concat_mask" : mask })
0 commit comments