
Commit 0c8d878

Update WanAnimateToVideo to more easily extend videos. (comfyanonymous#9959)
1 parent: 919517c

2 files changed (+45, -16 lines)

comfy/ldm/wan/model_animate.py
Lines changed: 1 addition & 1 deletion

@@ -451,7 +451,7 @@ def __init__(self,
     def after_patch_embedding(self, x, pose_latents, face_pixel_values):
         if pose_latents is not None:
             pose_latents = self.pose_patch_embedding(pose_latents)
-            x[:, :, 1:] += pose_latents
+            x[:, :, 1:pose_latents.shape[2] + 1] += pose_latents[:, :, :x.shape[2] - 1]
 
         if face_pixel_values is None:
             return x, None
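This one-line change makes the pose conditioning tolerant of length mismatches: instead of assuming pose_latents covers every latent frame after the reference frame, the addition is restricted to the frames the two tensors actually share. A minimal sketch of the indexing, with hypothetical shapes (the real tensors come from the Wan patch embedding):

import torch

# Hypothetical shapes: x holds 1 reference latent frame plus 21 generated latent
# frames, while the pose conditioning only covers 16 of them (e.g. the pose video
# ran out near the end of a chunk).
x = torch.zeros(1, 8, 1 + 21, 4, 4)
pose_latents = torch.zeros(1, 8, 16, 4, 4)

# Old behaviour: x[:, :, 1:] += pose_latents fails unless both spans match exactly.
# New behaviour: only the overlapping latent frames receive the pose conditioning.
x[:, :, 1:pose_latents.shape[2] + 1] += pose_latents[:, :, :x.shape[2] - 1]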

comfy_extras/nodes_wan.py
Lines changed: 44 additions & 15 deletions

@@ -1128,18 +1128,22 @@ def define_schema(cls):
                 io.Image.Input("pose_video", optional=True),
                 io.Int.Input("continue_motion_max_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4),
                 io.Image.Input("continue_motion", optional=True),
+                io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="The amount of frames to seek in all the input videos. Used for generating longer videos by chunk. Connect to the video_frame_offset output of the previous node for extending a video."),
             ],
             outputs=[
                 io.Conditioning.Output(display_name="positive"),
                 io.Conditioning.Output(display_name="negative"),
                 io.Latent.Output(display_name="latent"),
                 io.Int.Output(display_name="trim_latent"),
+                io.Int.Output(display_name="trim_image"),
+                io.Int.Output(display_name="video_frame_offset"),
             ],
             is_experimental=True,
         )
 
     @classmethod
-    def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None) -> io.NodeOutput:
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, video_frame_offset, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None) -> io.NodeOutput:
+        trim_to_pose_video = False
         latent_length = ((length - 1) // 4) + 1
         latent_width = width // 8
         latent_height = height // 8
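The new video_frame_offset input and the trim_image / video_frame_offset outputs are what let chunked extension be wired node to node. The surrounding context keeps the existing latent size arithmetic, which assumes Wan's temporal compression of 4 pixel frames per latent frame; a small worked example:

# Worked example of the latent-length arithmetic used in execute(), assuming
# Wan's 4x temporal compression (the first latent frame covers 1 pixel frame,
# each later latent frame covers 4).
length = 77                                # requested pixel frames for this chunk
latent_length = ((length - 1) // 4) + 1    # 77 frames -> 20 latent frames
assert latent_length == 20

# The inverse relation used later in this commit: a whole number of latent frames
# decodes back to (latent_length * 4 - 3) pixel frames.
assert latent_length * 4 - 3 == 77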
@@ -1152,35 +1156,60 @@ def execute(cls, positive, negative, vae, width, height, length, batch_size, con
             concat_latent_image = vae.encode(image[:, :, :, :3])
             mask = torch.zeros((1, 1, concat_latent_image.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype)
             trim_latent += concat_latent_image.shape[2]
+        ref_motion_latent_length = 0
+
+        if continue_motion is None:
+            image = torch.ones((length, height, width, 3)) * 0.5
+        else:
+            continue_motion = continue_motion[-continue_motion_max_frames:]
+            video_frame_offset -= continue_motion.shape[0]
+            video_frame_offset = max(0, video_frame_offset)
+            continue_motion = comfy.utils.common_upscale(continue_motion[-length:].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
+            image = torch.ones((length, height, width, continue_motion.shape[-1]), device=continue_motion.device, dtype=continue_motion.dtype) * 0.5
+            image[:continue_motion.shape[0]] = continue_motion
+            ref_motion_latent_length += ((continue_motion.shape[0] - 1) // 4) + 1
 
         if clip_vision_output is not None:
             positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
             negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
 
-        if face_video is not None:
-            face_video = comfy.utils.common_upscale(face_video[:length].movedim(-1, 1), 512, 512, "area", "center") * 2.0 - 1.0
-            face_video = face_video.movedim(0, 1).unsqueeze(0)
-            positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video})
-            negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0})
+        if pose_video is not None:
+            if pose_video.shape[0] <= video_frame_offset:
+                pose_video = None
+            else:
+                pose_video = pose_video[video_frame_offset:]
 
         if pose_video is not None:
             pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
+            if not trim_to_pose_video:
+                if pose_video.shape[0] < length:
+                    pose_video = torch.cat((pose_video,) + (pose_video[-1:],) * (length - pose_video.shape[0]), dim=0)
+
             pose_video_latent = vae.encode(pose_video[:, :, :, :3])
             positive = node_helpers.conditioning_set_values(positive, {"pose_video_latent": pose_video_latent})
            negative = node_helpers.conditioning_set_values(negative, {"pose_video_latent": pose_video_latent})
 
-        if continue_motion is None:
-            image = torch.ones((length, height, width, 3)) * 0.5
-        else:
-            continue_motion = continue_motion[-continue_motion_max_frames:]
-            continue_motion = comfy.utils.common_upscale(continue_motion[-length:].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
-            image = torch.ones((length, height, width, continue_motion.shape[-1]), device=continue_motion.device, dtype=continue_motion.dtype) * 0.5
-            image[:continue_motion.shape[0]] = continue_motion
+            if trim_to_pose_video:
+                latent_length = pose_video_latent.shape[2]
+                length = latent_length * 4 - 3
+                image = image[:length]
+
+        if face_video is not None:
+            if face_video.shape[0] <= video_frame_offset:
+                face_video = None
+            else:
+                face_video = face_video[video_frame_offset:]
+
+        if face_video is not None:
+            face_video = comfy.utils.common_upscale(face_video[:length].movedim(-1, 1), 512, 512, "area", "center") * 2.0 - 1.0
+            face_video = face_video.movedim(0, 1).unsqueeze(0)
+            positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video})
+            negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0})
 
         concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2)
         mask_refmotion = torch.ones((1, 1, latent_length, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype)
         if continue_motion is not None:
-            mask_refmotion[:, :, :((continue_motion.shape[0] - 1) // 4) + 1] = 0.0
+            mask_refmotion[:, :, :ref_motion_latent_length] = 0.0
 
         mask = torch.cat((mask, mask_refmotion), dim=2)
         positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
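The hunk above does the frame bookkeeping for chunked generation: continue_motion carries the tail frames of the previous chunk, so video_frame_offset is pulled back by that many frames before pose_video and face_video are sliced, and pose_video is padded with its last frame when the remainder is shorter than the requested length. A minimal sketch of the offset logic with hypothetical numbers (small frame sizes, purely illustrative):

import torch

# Hypothetical chunk state: the previous node reported an offset of 77 frames and
# handed over its last 5 frames as continue_motion (shape: frames, height, width, 3).
video_frame_offset = 77
continue_motion = torch.rand(5, 64, 64, 3)
pose_video = torch.rand(200, 64, 64, 3)

# The overlap frames were already rendered, so seek back by that many frames
# before consuming the driving videos (mirrors the logic in execute()).
video_frame_offset = max(0, video_frame_offset - continue_motion.shape[0])  # -> 72
pose_video = pose_video[video_frame_offset:]                                # frames 72..199

# If the remaining pose video is shorter than the chunk, repeat its last frame.
length = 77
if pose_video.shape[0] < length:
    pose_video = torch.cat((pose_video,) + (pose_video[-1:],) * (length - pose_video.shape[0]), dim=0)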
@@ -1189,7 +1218,7 @@ def execute(cls, positive, negative, vae, width, height, length, batch_size, con
         latent = torch.zeros([batch_size, 16, latent_length + trim_latent, latent_height, latent_width], device=comfy.model_management.intermediate_device())
         out_latent = {}
         out_latent["samples"] = latent
-        return io.NodeOutput(positive, negative, out_latent, trim_latent)
+        return io.NodeOutput(positive, negative, out_latent, trim_latent, max(0, ref_motion_latent_length * 4 - 3), video_frame_offset + length)
 
 class Wan22ImageToVideoLatent(io.ComfyNode):
     @classmethod
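The two new return values close the loop for chunk-by-chunk extension: trim_image, computed as max(0, ref_motion_latent_length * 4 - 3), is how many leading decoded frames merely repeat the previous chunk, and video_frame_offset + length is the seek position for the next chunk. A worked example with illustrative values, continuing the numbers from the sketch above:

# Worked example of the two new outputs, assuming a 5-frame continue_motion overlap
# and a 77-frame chunk (values are illustrative, not defaults enforced by the node).
continue_motion_frames = 5
length = 77
video_frame_offset = 72                    # after seeking back by the overlap, as above

ref_motion_latent_length = ((continue_motion_frames - 1) // 4) + 1   # -> 2 latent frames
trim_image = max(0, ref_motion_latent_length * 4 - 3)                # -> 5 frames to drop downstream
next_video_frame_offset = video_frame_offset + length                # -> 149 for the next node

assert (trim_image, next_video_frame_offset) == (5, 149)

Per the tooltip, the intended workflow is to connect the video_frame_offset output to the next WanAnimateToVideo node's video_frame_offset input, and to trim trim_image frames from the decoded result before concatenating chunks.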
