-
Notifications
You must be signed in to change notification settings - Fork 6.6k
Make ControlNet SD Training Script torch.compile compatible
#6525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
sayakpaul
merged 4 commits into
huggingface:main
from
suvadityamuk:controlnet_torch_compile
Jan 12, 2024
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
a364306
update: make controlnet script torch compile compatible
suvadityamuk d3215a8
update: correct earlier mistakes for compilation
suvadityamuk 75596e4
update: fix code style issues
suvadityamuk 6e02b0b
Merge branch 'main' into controlnet_torch_compile
suvadityamuk File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,6 +50,7 @@ | |
| from diffusers.optimization import get_scheduler | ||
| from diffusers.utils import check_min_version, is_wandb_available | ||
| from diffusers.utils.import_utils import is_xformers_available | ||
| from diffusers.utils.torch_utils import is_compiled_module | ||
|
|
||
|
|
||
| if is_wandb_available(): | ||
|
|
@@ -787,6 +788,12 @@ def main(args): | |
| logger.info("Initializing controlnet weights from unet") | ||
| controlnet = ControlNetModel.from_unet(unet) | ||
|
|
||
| # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files) | ||
| def unwrap_model(model): | ||
| model = accelerator.unwrap_model(model) | ||
| model = model._orig_mod if is_compiled_module(model) else model | ||
| return model | ||
|
|
||
| # `accelerate` 0.16.0 will have better support for customized saving | ||
| if version.parse(accelerate.__version__) >= version.parse("0.16.0"): | ||
| # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format | ||
|
|
@@ -846,9 +853,9 @@ def load_model_hook(models, input_dir): | |
| " doing mixed precision training, copy of the weights should still be float32." | ||
| ) | ||
|
|
||
| if accelerator.unwrap_model(controlnet).dtype != torch.float32: | ||
| if unwrap_model(controlnet).dtype != torch.float32: | ||
| raise ValueError( | ||
| f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}" | ||
| f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}" | ||
| ) | ||
|
|
||
| # Enable TF32 for faster training on Ampere GPUs, | ||
|
|
@@ -1015,7 +1022,7 @@ def load_model_hook(models, input_dir): | |
| noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
|
|
||
| # Get the text embedding for conditioning | ||
| encoder_hidden_states = text_encoder(batch["input_ids"])[0] | ||
| encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: For ControlNet training on SD, it won't matter (same for SDXL) as we never train the text encoder during ControlNet training. But keeping it this way doesn't hurt things. So, I am okay with it. |
||
|
|
||
| controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) | ||
|
|
||
|
|
@@ -1036,7 +1043,8 @@ def load_model_hook(models, input_dir): | |
| sample.to(dtype=weight_dtype) for sample in down_block_res_samples | ||
| ], | ||
| mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), | ||
| ).sample | ||
| return_dict=False, | ||
| )[0] | ||
|
|
||
| # Get the target for loss depending on the prediction type | ||
| if noise_scheduler.config.prediction_type == "epsilon": | ||
|
|
@@ -1109,7 +1117,7 @@ def load_model_hook(models, input_dir): | |
| # Create the pipeline using using the trained modules and save it. | ||
| accelerator.wait_for_everyone() | ||
| if accelerator.is_main_process: | ||
| controlnet = accelerator.unwrap_model(controlnet) | ||
| controlnet = unwrap_model(controlnet) | ||
| controlnet.save_pretrained(args.output_dir) | ||
|
|
||
| if args.push_to_hub: | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove it. Not a problem. But I appreciate the thoughtfulness.