Skip to content

Commit bc108e1

Browse files
AmericanPresidentJimmyCarterJimmysayakpaulyiyixuxu
authored
Fix DREAM training (#8302)
Co-authored-by: Jimmy <39@🇺🇸.com> Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: YiYi Xu <[email protected]>
1 parent 86555c9 commit bc108e1

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/diffusers/training_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,19 +157,19 @@ def compute_dream_and_update_latents(
157157
with torch.no_grad():
158158
pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
159159

160-
noisy_latents, target = (None, None)
160+
_noisy_latents, _target = (None, None)
161161
if noise_scheduler.config.prediction_type == "epsilon":
162162
predicted_noise = pred
163163
delta_noise = (noise - predicted_noise).detach()
164164
delta_noise.mul_(dream_lambda)
165-
noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
166-
target = target.add(delta_noise)
165+
_noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
166+
_target = target.add(delta_noise)
167167
elif noise_scheduler.config.prediction_type == "v_prediction":
168168
raise NotImplementedError("DREAM has not been implemented for v-prediction")
169169
else:
170170
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
171171

172-
return noisy_latents, target
172+
return _noisy_latents, _target
173173

174174

175175
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:

0 commit comments

Comments
 (0)