Refactor LoRA #3778
Changes from all commits
@@ -23,6 +23,7 @@
import shutil
import warnings
from pathlib import Path
from typing import Dict

import numpy as np
import torch
@@ -50,7 +51,10 @@
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.loaders import (
    LoraLoaderMixin,
    text_encoder_lora_state_dict,
)
from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
@@ -60,7 +64,7 @@
    SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -653,6 +657,22 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
    return prompt_embeds


def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
    r"""
    Returns:
        a state dict containing just the attention processor parameters.
    """
    attn_processors = unet.attn_processors

    attn_processors_state_dict = {}

    for attn_processor_key, attn_processor in attn_processors.items():
        for parameter_key, parameter in attn_processor.state_dict().items():
            attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter

    return attn_processors_state_dict


def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)
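For reference, here is a minimal sketch of the flattening the new `unet_attn_processors_state_dict` helper performs, using a toy stand-in module instead of a real UNet; the real keys are the UNet's attention-processor names (e.g. `mid_block.attentions.0.transformer_blocks.0.attn1.processor`), and the toy class below is purely illustrative:

```python
import torch

# Toy stand-in for a LoRA attention processor; real entries come from
# unet.attn_processors after unet.set_attn_processor(...) later in the script.
class ToyProcessor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q_lora = torch.nn.Linear(4, 4, bias=False)

attn_processors = {"mid_block.attentions.0.transformer_blocks.0.attn1.processor": ToyProcessor()}

# Same loop as unet_attn_processors_state_dict: prefix each parameter key with
# the processor's name so the flat dict can round-trip through a checkpoint.
flat = {}
for proc_name, proc in attn_processors.items():
    for param_name, param in proc.state_dict().items():
        flat[f"{proc_name}.{param_name}"] = param

print(list(flat))
# ['mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.weight']
```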
@@ -833,6 +853,7 @@ def main(args):

    # Set correct lora layers
    unet_lora_attn_procs = {}
    unet_lora_parameters = []
    for name, attn_processor in unet.attn_processors.items():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
@@ -850,35 +871,18 @@
        lora_attn_processor_class = (
            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
        )
        unet_lora_attn_procs[name] = lora_attn_processor_class(
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
            rank=args.rank,
        )

        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
        unet_lora_attn_procs[name] = module
        unet_lora_parameters.extend(module.parameters())

    unet.set_attn_processor(unet_lora_attn_procs)
    unet_lora_layers = AttnProcsLayers(unet.attn_processors)

    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
    # So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
    # we first load a dummy pipeline with the text encoder and then do the monkey-patching.
    text_encoder_lora_layers = None
    # So, instead, we monkey-patch the forward calls of its attention-blocks.
    if args.train_text_encoder:
        text_lora_attn_procs = {}
        for name, module in text_encoder.named_modules():
            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
                text_lora_attn_procs[name] = LoRAAttnProcessor(
                    hidden_size=module.out_proj.out_features,
                    cross_attention_dim=None,
                    rank=args.rank,
                )
        text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
        temp_pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path, text_encoder=text_encoder
        )
        temp_pipeline._modify_text_encoder(text_lora_attn_procs)
        text_encoder = temp_pipeline.text_encoder
        del temp_pipeline
        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)
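A quick sanity check (not part of the diff) of the parameter lists introduced in this hunk: since the base UNet and text encoder are frozen elsewhere in the script, only these collected LoRA parameters should be trainable. A sketch, assuming `unet_lora_parameters` and `text_lora_parameters` were built as above:

```python
# Hypothetical check, run right after the LoRA layers are installed.
n_unet_lora = sum(p.numel() for p in unet_lora_parameters)
print(f"unet LoRA parameters: {n_unet_lora}")

if args.train_text_encoder:
    n_text_lora = sum(p.numel() for p in text_lora_parameters)
    print(f"text encoder LoRA parameters: {n_text_lora}")
```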
@@ -887,23 +891,13 @@ def save_model_hook(models, weights, output_dir):
        unet_lora_layers_to_save = None
        text_encoder_lora_layers_to_save = None

        if args.train_text_encoder:
            text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
        unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()

        for model in models:
            state_dict = model.state_dict()

            if (
                text_encoder_lora_layers is not None
                and text_encoder_keys is not None
                and state_dict.keys() == text_encoder_keys
            ):
                # text encoder
                text_encoder_lora_layers_to_save = state_dict
            elif state_dict.keys() == unet_keys:
                # unet
                unet_lora_layers_to_save = state_dict
            if isinstance(model, type(accelerator.unwrap_model(unet))):
                unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
                text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

            # make sure to pop weight so that corresponding model is not saved again
            weights.pop()
@@ -915,27 +909,24 @@ def save_model_hook(models, weights, output_dir):
        )

    def load_model_hook(models, input_dir):
        # Note we DON'T pass the unet and text encoder here an purpose
        # so that the we don't accidentally override the LoRA layers of
        # unet_lora_layers and text_encoder_lora_layers which are stored in `models`
        # with new torch.nn.Modules / weights. We simply use the pipeline class as
        # an easy way to load the lora checkpoints
        temp_pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            revision=args.revision,
            torch_dtype=weight_dtype,
        )
        temp_pipeline.load_lora_weights(input_dir)
        unet_ = None
        text_encoder_ = None

        # load lora weights into models
        models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
        if len(models) > 1:
            models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
        while len(models) > 0:
            model = models.pop()

        # delete temporary pipeline and pop models
        del temp_pipeline
        for _ in range(len(models)):
            models.pop()
            if isinstance(model, type(accelerator.unwrap_model(unet))):
                unet_ = model
            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
                text_encoder_ = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
        LoraLoaderMixin.load_lora_into_text_encoder(
            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
        )

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
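For context, the two hooks registered here are driven by Accelerate's checkpointing entry points; a sketch of how they fire during training (the checkpoint path is a placeholder):

```python
# save_model_hook runs inside save_state and ends by calling
# LoraLoaderMixin.save_lora_weights with the collected LoRA state dicts;
# load_model_hook runs inside load_state and routes the restored weights
# back into the prepared unet / text encoder when resuming.
accelerator.save_state("output/checkpoint-500")
accelerator.load_state("output/checkpoint-500")
```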
@@ -965,9 +956,9 @@ def load_model_hook(models, input_dir):

    # Optimizer creation
    params_to_optimize = (
        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
        itertools.chain(unet_lora_parameters, text_lora_parameters)
        if args.train_text_encoder
        else unet_lora_layers.parameters()
        else unet_lora_parameters
    )
    optimizer = optimizer_class(
        params_to_optimize,
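One detail worth keeping in mind with the switch to plain parameter lists: `itertools.chain` yields a single-use iterator, which is why the script rebuilds the chain again for gradient clipping later instead of reusing `params_to_optimize`. A toy illustration:

```python
import itertools

unet_params, text_params = [1, 2], [3, 4]  # stand-ins for the parameter lists
chained = itertools.chain(unet_params, text_params)

print(list(chained))  # [1, 2, 3, 4]
print(list(chained))  # [] -- exhausted after one pass, so it cannot be reused
```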
@@ -1056,12 +1047,12 @@ def compute_text_embeddings(prompt):

    # Prepare everything with our `accelerator`.
    if args.train_text_encoder:
        unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet_lora_layers, optimizer, train_dataloader, lr_scheduler
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -1210,9 +1201,9 @@ def compute_text_embeddings(prompt):
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = (
                        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
                        itertools.chain(unet_lora_parameters, text_lora_parameters)
                        if args.train_text_encoder
                        else unet_lora_layers.parameters()
                        else unet_lora_parameters
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
@@ -1301,15 +1292,17 @@ def compute_text_embeddings(prompt):
                pipeline_args = {"prompt": args.validation_prompt}

                if args.validation_images is None:
                    images = [
                        pipeline(**pipeline_args, generator=generator).images[0]
                        for _ in range(args.num_validation_images)
                    ]
                    images = []
                    for _ in range(args.num_validation_images):
                        with torch.cuda.amp.autocast():
Review thread on this line:

Why these changes? We try to avoid running pipelines in autocast if possible.

I think I copy and pasted from the regular dreambooth training script, which runs the validation inference under autocast. Will remove.

Actually, I might have added this because when we load the rest of the model in fp16, we keep the lora weights in fp32 and needed autocast for them to work together.

Yes, looking through the commit history, that's why I added it. Is that ok?

For running inference validation in intervals, I think keeping autocast is okay as it helps keep things simple.

Why didn't we need it before?

Ok, so I was wrong that we needed it because of the difference in dtype in the lora weights. The lora layer casts manually internally. The issue was the dtype of the output of the unet being passed to the vae. The difference is that in the […] In this branch, the full unet is wrapped in amp, so even though the unet is loaded in fp16, the output is fp32, and then there's an error when the fp32 latents are passed to the fp16 vae.

The alternative to wrapping this section in amp is to put a check either in the pipeline or at the beginning of the vae for the dtype of the initial latents and manually cast them if necessary. I like to avoid manual casts like that in case the caller expects execution in the dtype of their input (which is a reasonable assumption imo). So I would prefer to leave the amp decorator in this case.
                            image = pipeline(**pipeline_args, generator=generator).images[0]
                        images.append(image)
                else:
                    images = []
                    for image in args.validation_images:
                        image = Image.open(image)
                        image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
                        with torch.cuda.amp.autocast():
                            image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
                        images.append(image)

                for tracker in accelerator.trackers:
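The review thread above comes down to a dtype clash between modules running in different precisions. A toy reproduction (requires a CUDA device; the layer is only a stand-in for the fp16 VAE) of why the autocast context is kept around the validation pipeline calls:

```python
import torch

fp16_layer = torch.nn.Linear(4, 4).cuda().half()  # stand-in for the fp16 VAE
fp32_latents = torch.randn(1, 4, device="cuda")   # fp32, like the prepared unet's output

try:
    fp16_layer(fp32_latents)        # raises a dtype-mismatch RuntimeError
except RuntimeError as err:
    print(err)

with torch.cuda.amp.autocast():
    out = fp16_layer(fp32_latents)  # autocast casts the matmul inputs to fp16
print(out.dtype)                    # torch.float16
```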
@@ -1332,12 +1325,16 @@ def compute_text_embeddings(prompt):
    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)
        unet = unet.to(torch.float32)
        unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
        unet_lora_layers = unet_attn_processors_state_dict(unet)

        if text_encoder is not None:
        if text_encoder is not None and args.train_text_encoder:
            text_encoder = accelerator.unwrap_model(text_encoder)
            text_encoder = text_encoder.to(torch.float32)
            text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
        else:
            text_encoder_lora_layers = None

        LoraLoaderMixin.save_lora_weights(
            save_directory=args.output_dir,
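After training, the directory written by `save_lora_weights` can be consumed through the same `LoraLoaderMixin` API on an inference pipeline. A sketch, where the model id, output path, and prompt are placeholders:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# load_lora_weights comes from LoraLoaderMixin and reads the weights produced
# by save_lora_weights above (args.output_dir in the training script).
pipe.load_lora_weights("path/to/output_dir")
image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
```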