Fix to generate one LoRAAttnProcessor for each CLIPAttention in TextEncoder LoRA #3505
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
I conducted qualitative tests using the dreambooth/lora script. The script I used is as follows:

```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="lora-trained"
accelerate launch ../examples/dreambooth/train_dreambooth_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--checkpointing_steps=100 \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=50 \
--train_text_encoder \
--seed="0" \
--push_to_hub
```

The result checkpoint: The following is dumped by https://gist.github.com/takuma104/dd7855909af17f3792b3704578c63a26. It seems that some learning has progressed.
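As a usage note (my own sketch, not part of the original comment): the resulting checkpoint could be loaded and sampled roughly as below, assuming the output directory from the script above and that the checkpoint layout written by train_dreambooth_lora.py is accepted by load_lora_weights.

```python
# Minimal sketch: load the LoRA weights produced by the training run above and
# render the validation prompt. "lora-trained" is the OUTPUT_DIR from the script;
# compatibility of the saved layout with load_lora_weights is assumed here.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("lora-trained")

image = pipe(
    "A photo of sks dog in a bucket",
    num_inference_steps=25,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("sks_dog_bucket.png")
```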
Nice, I think this is the missing piece here - great find @takuma104! @sayakpaul can you have a look here?
```
@@ -943,14 +943,19 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
    module = self.text_encoder.get_submodule(name)
    # Construct a new function that performs the LoRA merging. We will monkey patch
    # this forward pass.
    lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
    attn_processor_name = ".".join(name.split(".")[:-1])
```
To get the correct mapping in the names, as discovered in #3437 (comment).
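To illustrate what that mapping does (my own example; the module name below assumes the standard CLIP text encoder layout rather than anything specific to this PR):

```python
# The projection submodule name is truncated to the name of its parent
# CLIPAttention module, which becomes the single key for that block's processor.
name = "text_model.encoder.layers.0.self_attn.q_proj"  # example CLIP submodule name
attn_processor_name = ".".join(name.split(".")[:-1])
print(attn_processor_name)  # text_model.encoder.layers.0.self_attn
```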
```diff
  text_lora_attn_procs[name] = LoRAAttnProcessor(
-     hidden_size=module.out_features, cross_attention_dim=None
+     hidden_size=module.out_proj.out_features, cross_attention_dim=None
```
Looks good to me!
However, does the following need to be changed since that LoRA layer mapping is now being changed?
diffusers/src/diffusers/loaders.py, line 74 in 2f997f3:

```python
self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
```
Cc: @patrickvonplaten, given https://github.com/huggingface/diffusers/pull/3505/files#diff-eca8763d65ac395bf286dd84b25abcf5e299a737508155270a648534d781e865R34, do we need to change the split keys here for the text encoder?
Nice catch! The current code becomes a problem when the remap_key() function is called, so I've fixed it in 160a4d3. However, it seems that the current test code and the dreambooth script do not hit the conditions under which remap_key() is called, so I haven't been able to test it. As far as I understand, the intention is that this path runs when AttnProcsLayers is loaded directly with load_state_dict(), but are there any use cases where that happens?
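To make the concern concrete, here is a rough sketch (my own illustration, not the actual loaders.py code) of how a split-key based remap behaves on the new per-attention keys; the key format below is an assumption based on LoRAAttnProcessor's to_q_lora/.up internals, and ".self_attn" is one plausible replacement split key rather than the exact value in the fix commit.

```python
# Sketch of a split-key remap: truncate a flat state-dict key back to the
# name of the processor it belongs to.
def remap_key(key, split_keys):
    for k in split_keys:
        if k in key:
            return key.split(k)[0] + k
    raise ValueError(f"{key} does not match any of {split_keys}")

key = "text_model.encoder.layers.0.self_attn.to_q_lora.up.weight"  # assumed new-style key

# With the old split keys none of the substrings occur, so the remap raises.
old_split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
# remap_key(key, old_split_keys)  # -> ValueError

# A ".self_attn"-style split key recovers the per-CLIPAttention processor name.
print(remap_key(key, [".processor", ".self_attn"]))  # text_model.encoder.layers.0.self_attn
```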
Thanks! Yeah I think so too. @patrickvonplaten could you confirm once?
```python
    return text_lora_attn_procs


def create_text_encoder_lora_layers(text_encoder: nn.Module):
```
Seems like we're not using this method, so it's okay with me to discard it.
I've left this in, as it still seems to be used inside LoraLoaderMixinTests.get_dummy_components().
I think we can merge this once #3505 (comment) is addressed. Then we can also merge #3490 as is, no?
An inspiring update: #3437 (comment). @takuma104, I think for the utilities we can sum everything up (this PR, #3437, #3490) this week. Let me know if anything needs attention, testing, etc.
@sayakpaul I have made fixes regarding #3505 (comment). (As I commented, I have not been able to test it.)
Ran some finetunes with these PR changes. Qualitatively, results are looking promising. Thanks for the efforts!
Closing in favor of #3437.
What's this?
Discussed in PR #3437, #3437 (comment).
This PR fixes how the LoRAAttnProcessor in the text_encoder is created, changing it so that one LoRAAttnProcessor is created per CLIPAttention: the commented part of the following code is changed to the line below it. The current Diffusers code generates four LoRAAttnProcessors for each CLIPAttention, because TEXT_ENCODER_TARGET_MODULES has four keys. self_attn is the CLIPAttention class, which is the equivalent of the Attention class in Diffusers, so I thought it would be appropriate to generate the LoRAAttnProcessor in a 1:1 relationship with it. See also: #3437 (comment).
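For reference, here is a minimal sketch of the 1:1 creation described above (my own sketch rather than the code snippet from the PR; build_text_encoder_lora_procs is a hypothetical helper name, and the snippet assumes the CLIP text encoder's text_model.encoder.layers.N.self_attn naming and diffusers' LoRAAttnProcessor):

```python
from diffusers.models.attention_processor import LoRAAttnProcessor


def build_text_encoder_lora_procs(text_encoder):
    # One LoRAAttnProcessor per CLIPAttention block (module names ending in
    # "self_attn"), instead of one per q/k/v/out projection as with the old
    # TEXT_ENCODER_TARGET_MODULES-based loop.
    text_lora_attn_procs = {}
    for name, module in text_encoder.named_modules():
        if name.endswith("self_attn"):
            text_lora_attn_procs[name] = LoRAAttnProcessor(
                hidden_size=module.out_proj.out_features, cross_attention_dim=None
            )
    return text_lora_attn_procs
```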
Todo:
- [ ] This PR changes the keys of the weights. Compatibility issues arise with checkpoints already trained with the --train_text_encoder option; to accommodate this, handling in the loader will be necessary.

Note:
This PR currently branches from the working branch of #3490, so the Files changed tab also includes the changes from #3490. Please refer to this diff for the commits unique to this PR.