Skip to content

Commit 6d71bff

Browse files
bigcat88gmaOCR
authored andcommitted
convert nodes_pag.py to V3 schema (comfyanonymous#10080)
1 parent acce273 commit 6d71bff

File tree

1 file changed

+31
-18
lines changed

1 file changed

+31
-18
lines changed

comfy_extras/nodes_pag.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,30 @@
33

44
#My modified one here is more basic but has less chances of breaking with ComfyUI updates.
55

6+
from typing_extensions import override
7+
68
import comfy.model_patcher
79
import comfy.samplers
10+
from comfy_api.latest import ComfyExtension, io
811

9-
class PerturbedAttentionGuidance:
10-
@classmethod
11-
def INPUT_TYPES(s):
12-
return {
13-
"required": {
14-
"model": ("MODEL",),
15-
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
16-
}
17-
}
18-
19-
RETURN_TYPES = ("MODEL",)
20-
FUNCTION = "patch"
2112

22-
CATEGORY = "model_patches/unet"
13+
class PerturbedAttentionGuidance(io.ComfyNode):
14+
@classmethod
15+
def define_schema(cls):
16+
return io.Schema(
17+
node_id="PerturbedAttentionGuidance",
18+
category="model_patches/unet",
19+
inputs=[
20+
io.Model.Input("model"),
21+
io.Float.Input("scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01),
22+
],
23+
outputs=[
24+
io.Model.Output(),
25+
],
26+
)
2327

24-
def patch(self, model, scale):
28+
@classmethod
29+
def execute(cls, model, scale) -> io.NodeOutput:
2530
unet_block = "middle"
2631
unet_block_id = 0
2732
m = model.clone()
@@ -49,8 +54,16 @@ def post_cfg_function(args):
4954

5055
m.set_model_sampler_post_cfg_function(post_cfg_function)
5156

52-
return (m,)
57+
return io.NodeOutput(m)
58+
59+
60+
class PAGExtension(ComfyExtension):
61+
@override
62+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
63+
return [
64+
PerturbedAttentionGuidance,
65+
]
66+
5367

54-
NODE_CLASS_MAPPINGS = {
55-
"PerturbedAttentionGuidance": PerturbedAttentionGuidance,
56-
}
68+
async def comfy_entrypoint() -> PAGExtension:
69+
return PAGExtension()

0 commit comments

Comments
 (0)