11#Taken from: https://github.com/tfernd/HyperTile/
22
33import math
4+ from typing_extensions import override
45from einops import rearrange
56# Use torch rng for consistency across generations
67from torch import randint
8+ from comfy_api .latest import ComfyExtension , io
79
810def random_divisor (value : int , min_value : int , / , max_options : int = 1 ) -> int :
911 min_value = min (min_value , value )
@@ -20,25 +22,31 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
2022
2123 return ns [idx ]
2224
23- class HyperTile :
25+ class HyperTile ( io . ComfyNode ) :
2426 @classmethod
25- def INPUT_TYPES (s ):
26- return {"required" : { "model" : ("MODEL" ,),
27- "tile_size" : ("INT" , {"default" : 256 , "min" : 1 , "max" : 2048 }),
28- "swap_size" : ("INT" , {"default" : 2 , "min" : 1 , "max" : 128 }),
29- "max_depth" : ("INT" , {"default" : 0 , "min" : 0 , "max" : 10 }),
30- "scale_depth" : ("BOOLEAN" , {"default" : False }),
31- }}
32- RETURN_TYPES = ("MODEL" ,)
33- FUNCTION = "patch"
34-
35- CATEGORY = "model_patches/unet"
36-
37- def patch (self , model , tile_size , swap_size , max_depth , scale_depth ):
27+ def define_schema (cls ):
28+ return io .Schema (
29+ node_id = "HyperTile" ,
30+ category = "model_patches/unet" ,
31+ inputs = [
32+ io .Model .Input ("model" ),
33+ io .Int .Input ("tile_size" , default = 256 , min = 1 , max = 2048 ),
34+ io .Int .Input ("swap_size" , default = 2 , min = 1 , max = 128 ),
35+ io .Int .Input ("max_depth" , default = 0 , min = 0 , max = 10 ),
36+ io .Boolean .Input ("scale_depth" , default = False ),
37+ ],
38+ outputs = [
39+ io .Model .Output (),
40+ ],
41+ )
42+
43+ @classmethod
44+ def execute (cls , model , tile_size , swap_size , max_depth , scale_depth ) -> io .NodeOutput :
3845 latent_tile_size = max (32 , tile_size ) // 8
39- self . temp = None
46+ temp = None
4047
4148 def hypertile_in (q , k , v , extra_options ):
49+ nonlocal temp
4250 model_chans = q .shape [- 2 ]
4351 orig_shape = extra_options ['original_shape' ]
4452 apply_to = []
@@ -58,14 +66,15 @@ def hypertile_in(q, k, v, extra_options):
5866
5967 if nh * nw > 1 :
6068 q = rearrange (q , "b (nh h nw w) c -> (b nh nw) (h w) c" , h = h // nh , w = w // nw , nh = nh , nw = nw )
61- self . temp = (nh , nw , h , w )
69+ temp = (nh , nw , h , w )
6270 return q , k , v
6371
6472 return q , k , v
6573 def hypertile_out (out , extra_options ):
66- if self .temp is not None :
67- nh , nw , h , w = self .temp
68- self .temp = None
74+ nonlocal temp
75+ if temp is not None :
76+ nh , nw , h , w = temp
77+ temp = None
6978 out = rearrange (out , "(b nh nw) hw c -> b nh nw hw c" , nh = nh , nw = nw )
7079 out = rearrange (out , "b nh nw (h w) c -> b (nh h nw w) c" , h = h // nh , w = w // nw )
7180 return out
@@ -76,6 +85,14 @@ def hypertile_out(out, extra_options):
7685 m .set_model_attn1_output_patch (hypertile_out )
7786 return (m , )
7887
79- NODE_CLASS_MAPPINGS = {
80- "HyperTile" : HyperTile ,
81- }
88+
89+ class HyperTileExtension (ComfyExtension ):
90+ @override
91+ async def get_node_list (self ) -> list [type [io .ComfyNode ]]:
92+ return [
93+ HyperTile ,
94+ ]
95+
96+
97+ async def comfy_entrypoint () -> HyperTileExtension :
98+ return HyperTileExtension ()
0 commit comments