Commit c004327

bigcat88 authored and adlerfaulkner committed
convert nodes_hunyuan.py to V3 schema (comfyanonymous#10136)
1 parent ae79281 · commit c004327

File tree

1 file changed: +153 -94 lines changed


comfy_extras/nodes_hunyuan.py

Lines changed: 153 additions & 94 deletions
@@ -2,42 +2,60 @@
 import node_helpers
 import torch
 import comfy.model_management
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io


-class CLIPTextEncodeHunyuanDiT:
+class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "clip": ("CLIP", ),
-            "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
-            "mt5xl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
-            }}
-    RETURN_TYPES = ("CONDITIONING",)
-    FUNCTION = "encode"
-
-    CATEGORY = "advanced/conditioning"
-
-    def encode(self, clip, bert, mt5xl):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="CLIPTextEncodeHunyuanDiT",
+            category="advanced/conditioning",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.String.Input("bert", multiline=True, dynamic_prompts=True),
+                io.String.Input("mt5xl", multiline=True, dynamic_prompts=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, clip, bert, mt5xl) -> io.NodeOutput:
         tokens = clip.tokenize(bert)
         tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]

-        return (clip.encode_from_tokens_scheduled(tokens), )
+        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))

-class EmptyHunyuanLatentVideo:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                              "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                              "length": ("INT", {"default": 25, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "generate"
+    encode = execute # TODO: remove

-    CATEGORY = "latent/video"

-    def generate(self, width, height, length, batch_size=1):
+class EmptyHunyuanLatentVideo(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="EmptyHunyuanLatentVideo",
+            category="latent/video",
+            inputs=[
+                io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=25, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
-        return ({"samples":latent}, )
+        return io.NodeOutput({"samples":latent})
+
+    generate = execute # TODO: remove
+

 PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
     "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
@@ -50,45 +68,61 @@ def generate(self, width, height, length, batch_size=1):
     "<|start_header_id|>assistant<|end_header_id|>\n\n"
 )

-class TextEncodeHunyuanVideo_ImageToVideo:
+class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="TextEncodeHunyuanVideo_ImageToVideo",
+            category="advanced/conditioning",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.ClipVisionOutput.Input("clip_vision_output"),
+                io.String.Input("prompt", multiline=True, dynamic_prompts=True),
+                io.Int.Input(
+                    "image_interleave",
+                    default=2,
+                    min=1,
+                    max=512,
+                    tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.",
+                ),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
+
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "clip": ("CLIP", ),
-            "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
-            "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
-            "image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}),
-            }}
-    RETURN_TYPES = ("CONDITIONING",)
-    FUNCTION = "encode"
-
-    CATEGORY = "advanced/conditioning"
-
-    def encode(self, clip, clip_vision_output, prompt, image_interleave):
+    def execute(cls, clip, clip_vision_output, prompt, image_interleave) -> io.NodeOutput:
         tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
-        return (clip.encode_from_tokens_scheduled(tokens), )
+        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
+
+    encode = execute # TODO: remove
+

-class HunyuanImageToVideo:
+class HunyuanImageToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                             "guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
-                             },
-                "optional": {"start_image": ("IMAGE", ),
-                             }}
-
-    RETURN_TYPES = ("CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="HunyuanImageToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"]),
+                io.Image.Input("start_image", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, positive, vae, width, height, length, batch_size, guidance_type, start_image=None) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         out_latent = {}

@@ -111,51 +145,76 @@ def encode(self, positive, vae, width, height, length, batch_size, guidance_type
         positive = node_helpers.conditioning_set_values(positive, cond)

         out_latent["samples"] = latent
-        return (positive, out_latent)
+        return io.NodeOutput(positive, out_latent)

-class EmptyHunyuanImageLatent:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "width": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
-                              "height": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
-                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "generate"
+    encode = execute # TODO: remove

-    CATEGORY = "latent"

-    def generate(self, width, height, batch_size=1):
-        latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
-        return ({"samples":latent}, )
+class EmptyHunyuanImageLatent(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="EmptyHunyuanImageLatent",
+            category="latent",
+            inputs=[
+                io.Int.Input("width", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
+                io.Int.Input("height", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )

-class HunyuanRefinerLatent:
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "latent": ("LATENT", ),
-                             "noise_augmentation": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01}),
-                             }}
+    def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
+        latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
+        return io.NodeOutput({"samples":latent})

-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
+    generate = execute # TODO: remove

-    FUNCTION = "execute"

-    def execute(self, positive, negative, latent, noise_augmentation):
+class HunyuanRefinerLatent(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="HunyuanRefinerLatent",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Latent.Input("latent"),
+                io.Float.Input("noise_augmentation", default=0.10, min=0.0, max=1.0, step=0.01),
+
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, positive, negative, latent, noise_augmentation) -> io.NodeOutput:
         latent = latent["samples"]
         positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
         negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
         out_latent = {}
         out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
-        return (positive, negative, out_latent)
+        return io.NodeOutput(positive, negative, out_latent)
+
+
+class HunyuanExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            CLIPTextEncodeHunyuanDiT,
+            TextEncodeHunyuanVideo_ImageToVideo,
+            EmptyHunyuanLatentVideo,
+            HunyuanImageToVideo,
+            EmptyHunyuanImageLatent,
+            HunyuanRefinerLatent,
+        ]


-NODE_CLASS_MAPPINGS = {
-    "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
-    "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
-    "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
-    "HunyuanImageToVideo": HunyuanImageToVideo,
-    "EmptyHunyuanImageLatent": EmptyHunyuanImageLatent,
-    "HunyuanRefinerLatent": HunyuanRefinerLatent,
-}
+async def comfy_entrypoint() -> HunyuanExtension:
+    return HunyuanExtension()
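For readers unfamiliar with the V3 node API this commit targets, the sketch below distills the pattern visible in the diff: a node subclasses io.ComfyNode, declares its inputs and outputs in define_schema() via io.Schema, returns results wrapped in io.NodeOutput from a classmethod execute(), and is registered through a ComfyExtension returned by comfy_entrypoint() instead of NODE_CLASS_MAPPINGS. This is a minimal sketch based only on the API surface shown in the diff; the names ExampleEmptyLatent and ExampleExtension, and the latent shape used, are illustrative and not part of the commit.

import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io


class ExampleEmptyLatent(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        # Inputs/outputs are declared in the schema instead of INPUT_TYPES/RETURN_TYPES.
        return io.Schema(
            node_id="ExampleEmptyLatent",  # illustrative id, not in the commit
            category="latent",
            inputs=[
                io.Int.Input("width", default=512, min=16, max=8192, step=8),
                io.Int.Input("height", default=512, min=16, max=8192, step=8),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
            ],
            outputs=[
                io.Latent.Output(),
            ],
        )

    @classmethod
    def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
        # execute() is a classmethod and wraps its result in io.NodeOutput
        # instead of returning a tuple. The channel count here is illustrative.
        latent = torch.zeros([batch_size, 4, height // 8, width // 8])
        return io.NodeOutput({"samples": latent})


class ExampleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        # Nodes are registered by the extension rather than NODE_CLASS_MAPPINGS.
        return [ExampleEmptyLatent]


async def comfy_entrypoint() -> ExampleExtension:
    return ExampleExtension()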
