11# Code based on https://github.com/WikiChao/FreSca (MIT License)
22import torch
33import torch .fft as fft
4+ from typing_extensions import override
5+ from comfy_api .latest import ComfyExtension , io
46
57
68def Fourier_filter (x , scale_low = 1.0 , scale_high = 1.5 , freq_cutoff = 20 ):
@@ -51,25 +53,31 @@ def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
5153 return x_filtered
5254
5355
54- class FreSca :
56+ class FreSca ( io . ComfyNode ) :
5557 @classmethod
56- def INPUT_TYPES (s ):
57- return {
58- "required" : {
59- "model" : ("MODEL" ,),
60- "scale_low" : ("FLOAT" , {"default" : 1.0 , "min" : 0 , "max" : 10 , "step" : 0.01 ,
61- "tooltip" : "Scaling factor for low-frequency components" }),
62- "scale_high" : ("FLOAT" , {"default" : 1.25 , "min" : 0 , "max" : 10 , "step" : 0.01 ,
63- "tooltip" : "Scaling factor for high-frequency components" }),
64- "freq_cutoff" : ("INT" , {"default" : 20 , "min" : 1 , "max" : 10000 , "step" : 1 ,
65- "tooltip" : "Number of frequency indices around center to consider as low-frequency" }),
66- }
67- }
68- RETURN_TYPES = ("MODEL" ,)
69- FUNCTION = "patch"
70- CATEGORY = "_for_testing"
71- DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
72- def patch (self , model , scale_low , scale_high , freq_cutoff ):
58+ def define_schema (cls ):
59+ return io .Schema (
60+ node_id = "FreSca" ,
61+ display_name = "FreSca" ,
62+ category = "_for_testing" ,
63+ description = "Applies frequency-dependent scaling to the guidance" ,
64+ inputs = [
65+ io .Model .Input ("model" ),
66+ io .Float .Input ("scale_low" , default = 1.0 , min = 0 , max = 10 , step = 0.01 ,
67+ tooltip = "Scaling factor for low-frequency components" ),
68+ io .Float .Input ("scale_high" , default = 1.25 , min = 0 , max = 10 , step = 0.01 ,
69+ tooltip = "Scaling factor for high-frequency components" ),
70+ io .Int .Input ("freq_cutoff" , default = 20 , min = 1 , max = 10000 , step = 1 ,
71+ tooltip = "Number of frequency indices around center to consider as low-frequency" ),
72+ ],
73+ outputs = [
74+ io .Model .Output (),
75+ ],
76+ is_experimental = True ,
77+ )
78+
79+ @classmethod
80+ def execute (cls , model , scale_low , scale_high , freq_cutoff ):
7381 def custom_cfg_function (args ):
7482 conds_out = args ["conds_out" ]
7583 if len (conds_out ) <= 1 or None in args ["conds" ][:2 ]:
@@ -91,13 +99,16 @@ def custom_cfg_function(args):
9199 m = model .clone ()
92100 m .set_model_sampler_pre_cfg_function (custom_cfg_function )
93101
94- return (m ,)
102+ return io .NodeOutput (m )
103+
95104
105+ class FreScaExtension (ComfyExtension ):
106+ @override
107+ async def get_node_list (self ) -> list [type [io .ComfyNode ]]:
108+ return [
109+ FreSca ,
110+ ]
96111
97- NODE_CLASS_MAPPINGS = {
98- "FreSca" : FreSca ,
99- }
100112
101- NODE_DISPLAY_NAME_MAPPINGS = {
102- "FreSca" : "FreSca" ,
103- }
113+ async def comfy_entrypoint () -> FreScaExtension :
114+ return FreScaExtension ()
0 commit comments