@@ -138,6 +138,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
138138 else :
139139 operations = model_config .custom_operations
140140 self .diffusion_model = unet_model (** unet_config , device = device , operations = operations )
141+ self .diffusion_model .eval ()
141142 if comfy .model_management .force_channels_last ():
142143 self .diffusion_model .to (memory_format = torch .channels_last )
143144 logging .debug ("using channels last mode for diffusion model" )
@@ -669,7 +670,6 @@ def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG, device=None):
669670class StableCascade_C (BaseModel ):
670671 def __init__ (self , model_config , model_type = ModelType .STABLE_CASCADE , device = None ):
671672 super ().__init__ (model_config , model_type , device = device , unet_model = StageC )
672- self .diffusion_model .eval ().requires_grad_ (False )
673673
674674 def extra_conds (self , ** kwargs ):
675675 out = {}
@@ -698,7 +698,6 @@ def extra_conds(self, **kwargs):
698698class StableCascade_B (BaseModel ):
699699 def __init__ (self , model_config , model_type = ModelType .STABLE_CASCADE , device = None ):
700700 super ().__init__ (model_config , model_type , device = device , unet_model = StageB )
701- self .diffusion_model .eval ().requires_grad_ (False )
702701
703702 def extra_conds (self , ** kwargs ):
704703 out = {}
0 commit comments