Fix to generate one LoRAAttnProcessor for each CLIPAttention in TextEncoder LoRA #3505


Closed
wants to merge 12 commits

Conversation

@takuma104 (Contributor) commented May 22, 2023

What's this?

Discussed in PR #3437, #3437 (comment)

This PR fixes how the LoRAAttnProcessors for the text_encoder are created, so that one LoRAAttnProcessor is created per CLIPAttention. In the code below, the commented-out line is the current Diffusers behavior and the line beneath it is this PR's behavior. The current Diffusers code generates four LoRAAttnProcessors per CLIPAttention, because TEXT_ENCODER_TARGET_MODULES has four keys.

for name, module in text_encoder.named_modules():
    # if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): # current Diffusers
    if name.endswith('self_attn'): # this PR
        print(name)
        text_lora_attn_procs[name] = LoRAAttnProcessor(
            hidden_size=module.out_proj.out_features, cross_attention_dim=None
        )

self_attn is an instance of the CLIPAttention class, which is the counterpart of the Attention class in Diffusers, so it seemed appropriate to generate LoRAAttnProcessors in a 1:1 relationship with it. See also: #3437 (comment)
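
For illustration (not part of the PR's changes), a minimal sketch of that 1:1 relationship; the module and attribute names are those of transformers' CLIPAttention and Diffusers' LoRAAttnProcessor, and the checkpoint is only an example:

from transformers import CLIPTextModel
from diffusers.models.attention_processor import LoRAAttnProcessor

text_encoder = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"
)

# Each CLIPAttention block owns one linear layer per projection.
attn = text_encoder.text_model.encoder.layers[0].self_attn
print([n for n, _ in attn.named_children()])   # ['k_proj', 'v_proj', 'q_proj', 'out_proj']

# A single LoRAAttnProcessor already carries one LoRA layer per projection,
# so one processor per CLIPAttention covers all four projections.
proc = LoRAAttnProcessor(hidden_size=attn.out_proj.out_features, cross_attention_dim=None)
print([n for n, _ in proc.named_children()])   # ['to_q_lora', 'to_k_lora', 'to_v_lora', 'to_out_lora']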

Todo:

- [ ] This PR changes the weight keys, which breaks compatibility with checkpoints already trained with the --train_text_encoder option. To accommodate them, handling in the loader will be necessary.

Note:

This PR currently branches off the working branch of #3490, so the Files changed tab also includes the changes from #3490. Please refer to this diff for the commits unique to this PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@takuma104 (Contributor, Author) commented May 22, 2023

I conducted qualitative tests using the dreambooth/lora script. The command I used is as follows.

export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="lora-trained"

accelerate launch ../examples/dreambooth/train_dreambooth_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --checkpointing_steps=100 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=50 \
  --train_text_encoder \
  --seed="0" \
  --push_to_hub

The result checkpoint:
https://huggingface.co/takuma104/lora-test-text-encoder-lora-target

The following was dumped with https://gist.github.com/takuma104/dd7855909af17f3792b3704578c63a26. Since the to_*_lora.up.weight tensors are non-zero, it appears that some learning has taken place.

text_encoder.text_model.encoder.layers.0.self_attn.to_k_lora.down.weight [4, 768] mean=0.000592 std=0.248
text_encoder.text_model.encoder.layers.0.self_attn.to_k_lora.up.weight [768, 4] mean=-0.000118 std=0.00203
text_encoder.text_model.encoder.layers.0.self_attn.to_out_lora.down.weight [4, 768] mean=0.00265 std=0.248
text_encoder.text_model.encoder.layers.0.self_attn.to_out_lora.up.weight [768, 4] mean=7.15e-07 std=0.00217
text_encoder.text_model.encoder.layers.0.self_attn.to_q_lora.down.weight [4, 768] mean=0.00118 std=0.252
text_encoder.text_model.encoder.layers.0.self_attn.to_q_lora.up.weight [768, 4] mean=1.87e-05 std=0.00192
text_encoder.text_model.encoder.layers.0.self_attn.to_v_lora.down.weight [4, 768] mean=0.00242 std=0.251
text_encoder.text_model.encoder.layers.0.self_attn.to_v_lora.up.weight [768, 4] mean=-7.65e-06 std=0.00224
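
For reference, a minimal sketch of how this kind of dump can be produced (not the gist itself; it assumes the LoRA weights were saved to OUTPUT_DIR as pytorch_lora_weights.bin):

import torch

# Assumed output path of the training run above.
state_dict = torch.load("lora-trained/pytorch_lora_weights.bin", map_location="cpu")

# Print shape / mean / std for the text-encoder LoRA tensors.
for key in sorted(k for k in state_dict if k.startswith("text_encoder.")):
    t = state_dict[key].float()
    print(key, list(t.shape), f"mean={t.mean().item():.3g}", f"std={t.std().item():.3g}")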

@patrickvonplaten (Contributor)
Nice, I think this is the missing piece here - great find @takuma104! @sayakpaul can you have a look here?

@@ -943,14 +943,19 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
     module = self.text_encoder.get_submodule(name)
     # Construct a new function that performs the LoRA merging. We will monkey patch
     # this forward pass.
-    lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
+    attn_processor_name = ".".join(name.split(".")[:-1])
Member:
To get the correct mapping in the names as discovered in #3437 (comment)

     text_lora_attn_procs[name] = LoRAAttnProcessor(
-        hidden_size=module.out_features, cross_attention_dim=None
+        hidden_size=module.out_proj.out_features, cross_attention_dim=None
Member:
Looks good to me!

However, does the following need to be changed, since the LoRA layer mapping is now different?

self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]

Contributor Author:

Nice catch! The current code would be a problem when the remap_key() function is called, so I fixed it in 160a4d3. However, neither the current test code nor the dreambooth script seems to hit the condition where remap_key() is called, so I haven't been able to test it.

As far as I understand, the intention is that this code path is used when AttnProcsLayers is loaded directly with load_state_dict(), but are there any use cases that actually call it?
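
For illustration (not the exact library code), a minimal sketch of the remapping idea: remap_key() truncates a flat parameter key at one of the split_keys entries to recover its processor key, which is why the old per-projection entries no longer match the new per-self_attn keys:

# Illustrative only; the real logic lives in AttnProcsLayers' state-dict hooks.
split_keys = [".processor", ".self_attn"]  # assumed new list; the old one split on .k_proj/.q_proj/.v_proj/.out_proj

def remap_key(key: str) -> str:
    # Truncate a parameter key down to the processor key it belongs to.
    for split_key in split_keys:
        if split_key in key:
            return key.split(split_key)[0] + split_key
    raise ValueError(f"{key} could not be mapped to an attention processor key.")

print(remap_key("text_model.encoder.layers.0.self_attn.to_k_lora.down.weight"))
# -> text_model.encoder.layers.0.self_attn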

Member:

Thanks! Yeah I think so too. @patrickvonplaten could you confirm once?

return text_lora_attn_procs


def create_text_encoder_lora_layers(text_encoder: nn.Module):
Member:

It seems we're not using this method. I'm okay with discarding it.

Contributor Author:

I've left this as it still seems to be used inside LoraLoaderMixinTests.get_dummy_components().

@sayakpaul (Member)

I think we can merge this once #3505 (comment) is addressed. Then I think we can also merge #3490 as is, no?

@sayakpaul (Member)

An inspiring update: #3437 (comment).

@takuma104 I think we can wrap up the utilities (this PR, #3437, #3490) this week.

Let me know if anything needs attention, testing, etc.

@takuma104 (Contributor, Author)

@sayakpaul I have made the fixes regarding #3505 (comment). (As I commented, I haven't been able to test them.)

@takuma104 takuma104 marked this pull request as ready for review May 24, 2023 15:31
@rvorias commented May 24, 2023

Ran some finetunes with these PR changes. Qualitatively, results are looking promising. Thanks for the efforts!

@sayakpaul (Member)

Closing in favor of #3437.

@takuma104 takuma104 closed this May 29, 2023