|
| 1 | +import torch |
1 | 2 | from typing_extensions import override |
2 | 3 |
|
| 4 | +from comfy.k_diffusion.sampling import sigma_to_half_log_snr |
3 | 5 | from comfy_api.latest import ComfyExtension, io |
4 | 6 |
|
5 | 7 |
|
@@ -63,12 +65,105 @@ def epsilon_scaling_function(args): |
63 | 65 | return io.NodeOutput(model_clone) |
64 | 66 |
|
65 | 67 |
|
| 68 | +def compute_tsr_rescaling_factor( |
| 69 | + snr: torch.Tensor, tsr_k: float, tsr_variance: float |
| 70 | +) -> torch.Tensor: |
| 71 | + """Compute the rescaling score ratio in Temporal Score Rescaling. |
| 72 | +
|
| 73 | + See equation (6) in https://arxiv.org/pdf/2510.01184v1. |
| 74 | + """ |
| 75 | + posinf_mask = torch.isposinf(snr) |
| 76 | + rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1) |
| 77 | + return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k |
| 78 | + |
| 79 | + |
| 80 | +class TemporalScoreRescaling(io.ComfyNode): |
| 81 | + @classmethod |
| 82 | + def define_schema(cls): |
| 83 | + return io.Schema( |
| 84 | + node_id="TemporalScoreRescaling", |
| 85 | + display_name="TSR - Temporal Score Rescaling", |
| 86 | + category="model_patches/unet", |
| 87 | + inputs=[ |
| 88 | + io.Model.Input("model"), |
| 89 | + io.Float.Input( |
| 90 | + "tsr_k", |
| 91 | + tooltip=( |
| 92 | + "Controls the rescaling strength.\n" |
| 93 | + "Lower k produces more detailed results; higher k produces smoother results in image generation. Setting k = 1 disables rescaling." |
| 94 | + ), |
| 95 | + default=0.95, |
| 96 | + min=0.01, |
| 97 | + max=100.0, |
| 98 | + step=0.001, |
| 99 | + display_mode=io.NumberDisplay.number, |
| 100 | + ), |
| 101 | + io.Float.Input( |
| 102 | + "tsr_sigma", |
| 103 | + tooltip=( |
| 104 | + "Controls how early rescaling takes effect.\n" |
| 105 | + "Larger values take effect earlier." |
| 106 | + ), |
| 107 | + default=1.0, |
| 108 | + min=0.01, |
| 109 | + max=100.0, |
| 110 | + step=0.001, |
| 111 | + display_mode=io.NumberDisplay.number, |
| 112 | + ), |
| 113 | + ], |
| 114 | + outputs=[ |
| 115 | + io.Model.Output( |
| 116 | + display_name="patched_model", |
| 117 | + ), |
| 118 | + ], |
| 119 | + description=( |
| 120 | + "[Post-CFG Function]\n" |
| 121 | + "TSR - Temporal Score Rescaling (2510.01184)\n\n" |
| 122 | + "Rescaling the model's score or noise to steer the sampling diversity.\n" |
| 123 | + ), |
| 124 | + ) |
| 125 | + |
| 126 | + @classmethod |
| 127 | + def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput: |
| 128 | + tsr_variance = tsr_sigma**2 |
| 129 | + |
| 130 | + def temporal_score_rescaling(args): |
| 131 | + denoised = args["denoised"] |
| 132 | + x = args["input"] |
| 133 | + sigma = args["sigma"] |
| 134 | + curr_model = args["model"] |
| 135 | + |
| 136 | + # No rescaling (r = 1) or no noise |
| 137 | + if tsr_k == 1 or sigma == 0: |
| 138 | + return denoised |
| 139 | + |
| 140 | + model_sampling = curr_model.current_patcher.get_model_object("model_sampling") |
| 141 | + half_log_snr = sigma_to_half_log_snr(sigma, model_sampling) |
| 142 | + snr = (2 * half_log_snr).exp() |
| 143 | + |
| 144 | + # No rescaling needed (r = 1) |
| 145 | + if snr == 0: |
| 146 | + return denoised |
| 147 | + |
| 148 | + rescaling_r = compute_tsr_rescaling_factor(snr, tsr_k, tsr_variance) |
| 149 | + |
| 150 | + # Derived from scaled_denoised = (x - r * sigma * noise) / alpha |
| 151 | + alpha = sigma * half_log_snr.exp() |
| 152 | + return torch.lerp(x / alpha, denoised, rescaling_r) |
| 153 | + |
| 154 | + m = model.clone() |
| 155 | + m.set_model_sampler_post_cfg_function(temporal_score_rescaling) |
| 156 | + return io.NodeOutput(m) |
| 157 | + |
| 158 | + |
66 | 159 | class EpsilonScalingExtension(ComfyExtension): |
67 | 160 | @override |
68 | 161 | async def get_node_list(self) -> list[type[io.ComfyNode]]: |
69 | 162 | return [ |
70 | 163 | EpsilonScaling, |
| 164 | + TemporalScoreRescaling, |
71 | 165 | ] |
72 | 166 |
|
| 167 | + |
73 | 168 | async def comfy_entrypoint() -> EpsilonScalingExtension: |
74 | 169 | return EpsilonScalingExtension() |
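
A minimal standalone sketch of what the new hunk computes, assuming a toy variance-preserving step (the helper name tsr_factor and the alpha/sigma values are illustrative, not part of the patch): it evaluates the rescaling ratio r from equation (6) of arXiv:2510.01184, its snr → inf limit, and verifies that the torch.lerp form used in temporal_score_rescaling matches the direct formula (x - r * sigma * noise) / alpha.

    import torch

    def tsr_factor(snr: torch.Tensor, k: float, variance: float) -> torch.Tensor:
        # Equation (6): r = (snr * sigma_s^2 + 1) / (snr * sigma_s^2 / k + 1)
        r = (snr * variance + 1) / (snr * variance / k + 1)
        # As snr -> inf the ratio tends to k; patch the inf/inf = nan case.
        return torch.where(torch.isposinf(snr), torch.full_like(snr, k), r)

    # Toy variance-preserving step: alpha^2 + sigma^2 = 1 (values are assumptions).
    alpha, sigma = torch.tensor(0.8), torch.tensor(0.6)
    snr = (alpha / sigma) ** 2
    r = tsr_factor(snr, k=0.95, variance=1.0)

    # The hunk rewrites scaled_denoised = (x - r * sigma * noise) / alpha.
    # With noise = (x - alpha * denoised) / sigma this collapses to
    # (1 - r) * x / alpha + r * denoised, i.e. torch.lerp(x / alpha, denoised, r).
    denoised = torch.randn(4)
    noise = torch.randn(4)
    x = alpha * denoised + sigma * noise
    direct = (x - r * sigma * noise) / alpha
    via_lerp = torch.lerp(x / alpha, denoised, r)
    assert torch.allclose(direct, via_lerp, atol=1e-6)

    print(r.item())                                                  # ~0.97 for k = 0.95
    print(tsr_factor(torch.tensor(float("inf")), 0.95, 1.0).item())  # 0.95 == k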