|
1 | 1 | #Taken from: https://github.com/dbolya/tomesd |
2 | 2 |
|
3 | 3 | import torch |
4 | | -from typing import Tuple, Callable |
| 4 | +from typing import Tuple, Callable, Optional |
| 5 | +from typing_extensions import override |
| 6 | +from comfy_api.latest import ComfyExtension, io |
5 | 7 | import math |
6 | 8 |
|
7 | 9 | def do_nothing(x: torch.Tensor, mode:str=None): |
@@ -144,33 +146,45 @@ def get_functions(x, ratio, original_shape): |
144 | 146 |
|
145 | 147 |
|
146 | 148 |
|
147 | | -class TomePatchModel: |
| 149 | +class TomePatchModel(io.ComfyNode): |
148 | 150 | @classmethod |
149 | | - def INPUT_TYPES(s): |
150 | | - return {"required": { "model": ("MODEL",), |
151 | | - "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}), |
152 | | - }} |
153 | | - RETURN_TYPES = ("MODEL",) |
154 | | - FUNCTION = "patch" |
| 151 | + def define_schema(cls): |
| 152 | + return io.Schema( |
| 153 | + node_id="TomePatchModel", |
| 154 | + category="model_patches/unet", |
| 155 | + inputs=[ |
| 156 | + io.Model.Input("model"), |
| 157 | + io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01), |
| 158 | + ], |
| 159 | + outputs=[io.Model.Output()], |
| 160 | + ) |
155 | 161 |
|
156 | | - CATEGORY = "model_patches/unet" |
157 | | - |
158 | | - def patch(self, model, ratio): |
159 | | - self.u = None |
| 162 | + @classmethod |
| 163 | + def execute(cls, model, ratio) -> io.NodeOutput: |
| 164 | + u: Optional[Callable] = None |
160 | 165 | def tomesd_m(q, k, v, extra_options): |
| 166 | + nonlocal u |
161 | 167 | #NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q |
162 | 168 | #however from my basic testing it seems that using q instead gives better results |
163 | | - m, self.u = get_functions(q, ratio, extra_options["original_shape"]) |
| 169 | + m, u = get_functions(q, ratio, extra_options["original_shape"]) |
164 | 170 | return m(q), k, v |
165 | 171 | def tomesd_u(n, extra_options): |
166 | | - return self.u(n) |
| 172 | + nonlocal u |
| 173 | + return u(n) |
167 | 174 |
|
168 | 175 | m = model.clone() |
169 | 176 | m.set_model_attn1_patch(tomesd_m) |
170 | 177 | m.set_model_attn1_output_patch(tomesd_u) |
171 | | - return (m, ) |
| 178 | + return io.NodeOutput(m) |
| 179 | + |
| 180 | + |
| 181 | +class TomePatchModelExtension(ComfyExtension): |
| 182 | + @override |
| 183 | + async def get_node_list(self) -> list[type[io.ComfyNode]]: |
| 184 | + return [ |
| 185 | + TomePatchModel, |
| 186 | + ] |
172 | 187 |
|
173 | 188 |
|
174 | | -NODE_CLASS_MAPPINGS = { |
175 | | - "TomePatchModel": TomePatchModel, |
176 | | -} |
| 189 | +async def comfy_entrypoint() -> TomePatchModelExtension: |
| 190 | + return TomePatchModelExtension() |
0 commit comments