|  | 
| 3 | 3 | 
 | 
| 4 | 4 | #My modified one here is more basic but has less chances of breaking with ComfyUI updates. | 
| 5 | 5 | 
 | 
|  | 6 | +from typing_extensions import override | 
|  | 7 | + | 
| 6 | 8 | import comfy.model_patcher | 
| 7 | 9 | import comfy.samplers | 
|  | 10 | +from comfy_api.latest import ComfyExtension, io | 
| 8 | 11 | 
 | 
| 9 |  | -class PerturbedAttentionGuidance: | 
| 10 |  | -    @classmethod | 
| 11 |  | -    def INPUT_TYPES(s): | 
| 12 |  | -        return { | 
| 13 |  | -            "required": { | 
| 14 |  | -                "model": ("MODEL",), | 
| 15 |  | -                "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}), | 
| 16 |  | -            } | 
| 17 |  | -        } | 
| 18 |  | - | 
| 19 |  | -    RETURN_TYPES = ("MODEL",) | 
| 20 |  | -    FUNCTION = "patch" | 
| 21 | 12 | 
 | 
| 22 |  | -    CATEGORY = "model_patches/unet" | 
|  | 13 | +class PerturbedAttentionGuidance(io.ComfyNode): | 
|  | 14 | +    @classmethod | 
|  | 15 | +    def define_schema(cls): | 
|  | 16 | +        return io.Schema( | 
|  | 17 | +            node_id="PerturbedAttentionGuidance", | 
|  | 18 | +            category="model_patches/unet", | 
|  | 19 | +            inputs=[ | 
|  | 20 | +                io.Model.Input("model"), | 
|  | 21 | +                io.Float.Input("scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01), | 
|  | 22 | +            ], | 
|  | 23 | +            outputs=[ | 
|  | 24 | +                io.Model.Output(), | 
|  | 25 | +            ], | 
|  | 26 | +        ) | 
| 23 | 27 | 
 | 
| 24 |  | -    def patch(self, model, scale): | 
|  | 28 | +    @classmethod | 
|  | 29 | +    def execute(cls, model, scale) -> io.NodeOutput: | 
| 25 | 30 |         unet_block = "middle" | 
| 26 | 31 |         unet_block_id = 0 | 
| 27 | 32 |         m = model.clone() | 
| @@ -49,8 +54,16 @@ def post_cfg_function(args): | 
| 49 | 54 | 
 | 
| 50 | 55 |         m.set_model_sampler_post_cfg_function(post_cfg_function) | 
| 51 | 56 | 
 | 
| 52 |  | -        return (m,) | 
|  | 57 | +        return io.NodeOutput(m) | 
|  | 58 | + | 
|  | 59 | + | 
|  | 60 | +class PAGExtension(ComfyExtension): | 
|  | 61 | +    @override | 
|  | 62 | +    async def get_node_list(self) -> list[type[io.ComfyNode]]: | 
|  | 63 | +        return [ | 
|  | 64 | +            PerturbedAttentionGuidance, | 
|  | 65 | +        ] | 
|  | 66 | + | 
| 53 | 67 | 
 | 
| 54 |  | -NODE_CLASS_MAPPINGS = { | 
| 55 |  | -    "PerturbedAttentionGuidance": PerturbedAttentionGuidance, | 
| 56 |  | -} | 
|  | 68 | +async def comfy_entrypoint() -> PAGExtension: | 
|  | 69 | +    return PAGExtension() | 
0 commit comments