File tree Expand file tree Collapse file tree 2 files changed +61
-0
lines changed Expand file tree Collapse file tree 2 files changed +61
-0
lines changed Original file line number Diff line number Diff line change 1+ class EpsilonScaling :
2+ """
3+ Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
4+ (https://arxiv.org/abs/2308.15321v6).
5+
6+ This method mitigates exposure bias by scaling the predicted noise during sampling,
7+ which can significantly improve sample quality. This implementation uses the "uniform schedule"
8+ recommended by the paper for its practicality and effectiveness.
9+ """
10+ @classmethod
11+ def INPUT_TYPES (s ):
12+ return {
13+ "required" : {
14+ "model" : ("MODEL" ,),
15+ "scaling_factor" : ("FLOAT" , {
16+ "default" : 1.005 ,
17+ "min" : 0.5 ,
18+ "max" : 1.5 ,
19+ "step" : 0.001 ,
20+ "display" : "number"
21+ }),
22+ }
23+ }
24+
25+ RETURN_TYPES = ("MODEL" ,)
26+ FUNCTION = "patch"
27+
28+ CATEGORY = "model_patches/unet"
29+
30+ def patch (self , model , scaling_factor ):
31+ # Prevent division by zero, though the UI's min value should prevent this.
32+ if scaling_factor == 0 :
33+ scaling_factor = 1e-9
34+
35+ def epsilon_scaling_function (args ):
36+ """
37+ This function is applied after the CFG guidance has been calculated.
38+ It recalculates the denoised latent by scaling the predicted noise.
39+ """
40+ denoised = args ["denoised" ]
41+ x = args ["input" ]
42+
43+ noise_pred = x - denoised
44+
45+ scaled_noise_pred = noise_pred / scaling_factor
46+
47+ new_denoised = x - scaled_noise_pred
48+
49+ return new_denoised
50+
51+ # Clone the model patcher to avoid modifying the original model in place
52+ model_clone = model .clone ()
53+
54+ model_clone .set_model_sampler_post_cfg_function (epsilon_scaling_function )
55+
56+ return (model_clone ,)
57+
58+ NODE_CLASS_MAPPINGS = {
59+ "Epsilon Scaling" : EpsilonScaling
60+ }
Original file line number Diff line number Diff line change @@ -2297,6 +2297,7 @@ async def init_builtin_extra_nodes():
22972297 "nodes_gits.py" ,
22982298 "nodes_controlnet.py" ,
22992299 "nodes_hunyuan.py" ,
2300+ "nodes_eps.py" ,
23002301 "nodes_flux.py" ,
23012302 "nodes_lora_extract.py" ,
23022303 "nodes_torch_compile.py" ,
You can’t perform that action at this time.
0 commit comments