Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/text_to_image/train_text_to_image_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,7 +1084,7 @@ def unwrap_model(model):

# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do this conditionally if the dtype of noisy_model_input doesn't match with vae.dtype?

Copy link
Contributor Author

@leisuzz leisuzz Aug 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your feedback, actually I tested several times, and I found the dtype error actually happens in:

model_pred = unet(
                    noisy_model_input,
                    timesteps,
                    prompt_embeds,
                    added_cond_kwargs=unet_added_conditions,
                    return_dict=False,
                )[0]

So I think it may not affect the vae. Further, the model_pred will be model_pred.float() in the loss part.


# time ids
def compute_time_ids(original_size, crops_coords_top_left):
Expand All @@ -1101,7 +1101,7 @@ def compute_time_ids(original_size, crops_coords_top_left):

# Predict the noise residual
unet_added_conditions = {"time_ids": add_time_ids}
prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype)
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
model_pred = unet(
Expand Down