Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions comfy_extras/nodes_eps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
class EpsilonScaling:
"""
Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
(https://arxiv.org/abs/2308.15321v6).

This method mitigates exposure bias by scaling the predicted noise during sampling,
which can significantly improve sample quality. This implementation uses the "uniform schedule"
recommended by the paper for its practicality and effectiveness.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"scaling_factor": ("FLOAT", {
"default": 1.005,
"min": 0.5,
"max": 1.5,
"step": 0.001,
"display": "number"
}),
}
}

RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "model_patches/unet"

def patch(self, model, scaling_factor):
# Prevent division by zero, though the UI's min value should prevent this.
if scaling_factor == 0:
scaling_factor = 1e-9

def epsilon_scaling_function(args):
"""
This function is applied after the CFG guidance has been calculated.
It recalculates the denoised latent by scaling the predicted noise.
"""
denoised = args["denoised"]
x = args["input"]

noise_pred = x - denoised

scaled_noise_pred = noise_pred / scaling_factor

new_denoised = x - scaled_noise_pred

return new_denoised

# Clone the model patcher to avoid modifying the original model in place
model_clone = model.clone()

model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function)

return (model_clone,)

NODE_CLASS_MAPPINGS = {
"Epsilon Scaling": EpsilonScaling
}
1 change: 1 addition & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2297,6 +2297,7 @@ async def init_builtin_extra_nodes():
"nodes_gits.py",
"nodes_controlnet.py",
"nodes_hunyuan.py",
"nodes_eps.py",
"nodes_flux.py",
"nodes_lora_extract.py",
"nodes_torch_compile.py",
Expand Down