
Commit c730faf

gmaOCR authored and comfyanonymous committed
Add inputs for character replacement to the WanAnimateToVideo node. (comfyanonymous#9960)
1 parent ba217f7 commit c730faf

File tree

1 file changed (+33 -5)

comfy_extras/nodes_wan.py

Lines changed: 33 additions & 5 deletions
@@ -1127,6 +1127,8 @@ def define_schema(cls):
                 io.Image.Input("face_video", optional=True),
                 io.Image.Input("pose_video", optional=True),
                 io.Int.Input("continue_motion_max_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Image.Input("background_video", optional=True),
+                io.Mask.Input("character_mask", optional=True),
                 io.Image.Input("continue_motion", optional=True),
                 io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="The amount of frames to seek in all the input videos. Used for generating longer videos by chunk. Connect to the video_frame_offset output of the previous node for extending a video."),
             ],
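
For context, the two new schema inputs follow the standard ComfyUI tensor conventions: an IMAGE is a (frames, height, width, channels) float tensor in 0..1 and a MASK is (frames, height, width). A minimal sketch of inputs sized for the new code paths (the shapes are assumed from those conventions, not stated in the diff):

import torch

# Illustrative sizes only.
length, height, width = 77, 480, 832

background_video = torch.rand(length, height, width, 3)  # IMAGE: (frames, H, W, C)
character_mask = torch.ones(length, height, width)       # MASK: (frames, H, W)

# A single-frame mask also works: execute() repeats it across all `length`
# frames (see the character_mask branch in the last hunk below).
single_frame_mask = torch.ones(1, height, width)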
@@ -1142,7 +1144,7 @@ def define_schema(cls):
         )
 
     @classmethod
-    def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, video_frame_offset, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None) -> io.NodeOutput:
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, video_frame_offset, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None, background_video=None, character_mask=None) -> io.NodeOutput:
         trim_to_pose_video = False
         latent_length = ((length - 1) // 4) + 1
         latent_width = width // 8
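
The unchanged context lines show the bookkeeping the new inputs plug into: the Wan VAE compresses 4x in time (with one extra seed frame) and 8x in space, hence latent_length = ((length - 1) // 4) + 1. A quick worked check of that arithmetic, using a hypothetical helper:

# Compression arithmetic from the context lines above.
def latent_dims(length: int, width: int, height: int):
    latent_length = ((length - 1) // 4) + 1
    return latent_length, width // 8, height // 8

# A 77-frame 832x480 request maps to a 20-frame 104x60 latent.
assert latent_dims(77, 832, 480) == (20, 104, 60)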
@@ -1154,7 +1156,7 @@ def execute(cls, positive, negative, vae, width, height, length, batch_size, con
 
             image = comfy.utils.common_upscale(reference_image[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
             concat_latent_image = vae.encode(image[:, :, :, :3])
-            mask = torch.zeros((1, 1, concat_latent_image.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype)
+            mask = torch.zeros((1, 4, concat_latent_image.shape[-3], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype)
             trim_latent += concat_latent_image.shape[2]
         ref_motion_latent_length = 0
 
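Here the zeros block covering the reference image grows from 1 to 4 mask channels, and shape[2] becomes shape[-3] (the same frame axis of a 5-D latent, just addressed from the end). Four channels are needed because mask_refmotion is later folded to carry one mask value per pixel frame, i.e. four per latent frame. A shape-only sketch, with the latent channel count (16) assumed for illustration:

import torch

# (batch, channels, frames, h, w) video latent; sizes are illustrative.
concat_latent_image = torch.zeros(1, 16, 1, 60, 104)  # one reference frame

# Before: one mask value per latent frame.
old_mask = torch.zeros(1, 1, concat_latent_image.shape[2], 60, 104)
# After: four mask values per latent frame, matching the folded
# mask_refmotion it is concatenated with at the end of the last hunk.
new_mask = torch.zeros(1, 4, concat_latent_image.shape[-3], 60, 104)

print(old_mask.shape, new_mask.shape)  # (1, 1, 1, 60, 104) (1, 4, 1, 60, 104)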
@@ -1206,11 +1208,37 @@
             positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video})
             negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0})
 
-        concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2)
-        mask_refmotion = torch.ones((1, 1, latent_length, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype)
+        ref_images_num = max(0, ref_motion_latent_length * 4 - 3)
+        if background_video is not None:
+            if background_video.shape[0] > video_frame_offset:
+                background_video = background_video[video_frame_offset:]
+            background_video = comfy.utils.common_upscale(background_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
+            if background_video.shape[0] > ref_images_num:
+                image[ref_images_num:background_video.shape[0] - ref_images_num] = background_video[ref_images_num:]
+
+        mask_refmotion = torch.ones((1, 1, latent_length * 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype)
         if continue_motion is not None:
-            mask_refmotion[:, :, :ref_motion_latent_length] = 0.0
+            mask_refmotion[:, :, :ref_motion_latent_length * 4] = 0.0
+
+        if character_mask is not None:
+            if character_mask.shape[0] > video_frame_offset or character_mask.shape[0] == 1:
+                if character_mask.shape[0] == 1:
+                    character_mask = character_mask.repeat((length,) + (1,) * (character_mask.ndim - 1))
+                else:
+                    character_mask = character_mask[video_frame_offset:]
+                if character_mask.ndim == 3:
+                    character_mask = character_mask.unsqueeze(1)
+                character_mask = character_mask.movedim(0, 1)
+                if character_mask.ndim == 4:
+                    character_mask = character_mask.unsqueeze(1)
+                character_mask = comfy.utils.common_upscale(character_mask[:, :, :length], concat_latent_image.shape[-1], concat_latent_image.shape[-2], "nearest-exact", "center")
+                if character_mask.shape[2] > ref_images_num:
+                    mask_refmotion[:, :, ref_images_num:character_mask.shape[2] + ref_images_num] = character_mask[:, :, ref_images_num:]
+
+        concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2)
+
 
+        mask_refmotion = mask_refmotion.view(1, mask_refmotion.shape[2] // 4, 4, mask_refmotion.shape[3], mask_refmotion.shape[4]).transpose(1, 2)
         mask = torch.cat((mask, mask_refmotion), dim=2)
         positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
         negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
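
The last added line is the shape trick the rest of the hunk depends on: mask_refmotion is now built at pixel frame rate (latent_length * 4 frames, one channel) so background and character masks can be written per pixel frame, then folded into (1, 4, latent_length, h, w) to match the 4-channel reference mask it is concatenated with. A standalone sketch of that fold, with illustrative sizes:

import torch

latent_length, h, w = 20, 60, 104

# Per-pixel-frame mask as built in the diff: (1, 1, latent_length * 4, h, w).
mask_refmotion = torch.ones(1, 1, latent_length * 4, h, w)
mask_refmotion[:, :, :8] = 0.0  # e.g. 2 latent frames of continued motion

# Fold groups of 4 consecutive pixel frames into the channel axis:
# (1, 1, T*4, h, w) -> (1, T, 4, h, w) -> (1, 4, T, h, w).
folded = mask_refmotion.view(1, latent_length, 4, h, w).transpose(1, 2)
assert folded.shape == (1, 4, latent_length, h, w)

# Pixel frames 4*t .. 4*t+3 land in channels 0..3 of latent frame t,
# which is why the reference-image mask in the earlier hunk also needs
# 4 channels before the final torch.cat along the frame axis.
assert torch.equal(folded[0, :, 0], torch.zeros(4, h, w))
assert torch.equal(folded[0, :, 1], torch.zeros(4, h, w))
assert torch.equal(folded[0, :, 2], torch.ones(4, h, w))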
