Skip to content

Commit 681dd46

Browse files
chaObservadlerfaulkner
authored andcommitted
Add TemporalScoreRescaling node (comfyanonymous#10351)
* Add TemporalScoreRescaling node * Mention image generation in tsr_k's tooltip
1 parent 81c3fc7 commit 681dd46

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed

comfy_extras/nodes_eps.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import torch
12
from typing_extensions import override
23

4+
from comfy.k_diffusion.sampling import sigma_to_half_log_snr
35
from comfy_api.latest import ComfyExtension, io
46

57

@@ -63,12 +65,105 @@ def epsilon_scaling_function(args):
6365
return io.NodeOutput(model_clone)
6466

6567

68+
def compute_tsr_rescaling_factor(
69+
snr: torch.Tensor, tsr_k: float, tsr_variance: float
70+
) -> torch.Tensor:
71+
"""Compute the rescaling score ratio in Temporal Score Rescaling.
72+
73+
See equation (6) in https://arxiv.org/pdf/2510.01184v1.
74+
"""
75+
posinf_mask = torch.isposinf(snr)
76+
rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1)
77+
return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k
78+
79+
80+
class TemporalScoreRescaling(io.ComfyNode):
81+
@classmethod
82+
def define_schema(cls):
83+
return io.Schema(
84+
node_id="TemporalScoreRescaling",
85+
display_name="TSR - Temporal Score Rescaling",
86+
category="model_patches/unet",
87+
inputs=[
88+
io.Model.Input("model"),
89+
io.Float.Input(
90+
"tsr_k",
91+
tooltip=(
92+
"Controls the rescaling strength.\n"
93+
"Lower k produces more detailed results; higher k produces smoother results in image generation. Setting k = 1 disables rescaling."
94+
),
95+
default=0.95,
96+
min=0.01,
97+
max=100.0,
98+
step=0.001,
99+
display_mode=io.NumberDisplay.number,
100+
),
101+
io.Float.Input(
102+
"tsr_sigma",
103+
tooltip=(
104+
"Controls how early rescaling takes effect.\n"
105+
"Larger values take effect earlier."
106+
),
107+
default=1.0,
108+
min=0.01,
109+
max=100.0,
110+
step=0.001,
111+
display_mode=io.NumberDisplay.number,
112+
),
113+
],
114+
outputs=[
115+
io.Model.Output(
116+
display_name="patched_model",
117+
),
118+
],
119+
description=(
120+
"[Post-CFG Function]\n"
121+
"TSR - Temporal Score Rescaling (2510.01184)\n\n"
122+
"Rescaling the model's score or noise to steer the sampling diversity.\n"
123+
),
124+
)
125+
126+
@classmethod
127+
def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput:
128+
tsr_variance = tsr_sigma**2
129+
130+
def temporal_score_rescaling(args):
131+
denoised = args["denoised"]
132+
x = args["input"]
133+
sigma = args["sigma"]
134+
curr_model = args["model"]
135+
136+
# No rescaling (r = 1) or no noise
137+
if tsr_k == 1 or sigma == 0:
138+
return denoised
139+
140+
model_sampling = curr_model.current_patcher.get_model_object("model_sampling")
141+
half_log_snr = sigma_to_half_log_snr(sigma, model_sampling)
142+
snr = (2 * half_log_snr).exp()
143+
144+
# No rescaling needed (r = 1)
145+
if snr == 0:
146+
return denoised
147+
148+
rescaling_r = compute_tsr_rescaling_factor(snr, tsr_k, tsr_variance)
149+
150+
# Derived from scaled_denoised = (x - r * sigma * noise) / alpha
151+
alpha = sigma * half_log_snr.exp()
152+
return torch.lerp(x / alpha, denoised, rescaling_r)
153+
154+
m = model.clone()
155+
m.set_model_sampler_post_cfg_function(temporal_score_rescaling)
156+
return io.NodeOutput(m)
157+
158+
66159
class EpsilonScalingExtension(ComfyExtension):
67160
@override
68161
async def get_node_list(self) -> list[type[io.ComfyNode]]:
69162
return [
70163
EpsilonScaling,
164+
TemporalScoreRescaling,
71165
]
72166

167+
73168
async def comfy_entrypoint() -> EpsilonScalingExtension:
74169
return EpsilonScalingExtension()

0 commit comments

Comments
 (0)