Fix to generate one LoRAAttnProcessor for each CLIPAttention in TextEncoder LoRA #3505
Changes from all commits: fb708fb, 6e8f3ab, 8511755, 81915f4, 88db546, 1da772b, 8c0926c, 8a26848, 28c69ee, 3a74c7e, 160a4d3, f14329d
File: src/diffusers/loaders.py
```diff
@@ -70,8 +70,8 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]):
         self.mapping = dict(enumerate(state_dict.keys()))
         self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
 
-        # .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder
-        self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
+        # .processor for unet, .self_attn for text encoder
+        self.split_keys = [".processor", ".self_attn"]
 
         # we add a hook to state_dict() and load_state_dict() so that the
         # naming fits with `unet.attn_processors`
```
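For orientation, here is a minimal sketch of what splitting on `split_keys` accomplishes when flattened state-dict keys are mapped back to `attn_processors`-style names. The helper name and error handling below are illustrative assumptions, not the actual `remap_key` code in `loaders.py`; the point is that a text-encoder key must now be cut at `.self_attn` (the module each processor is keyed by) rather than at the individual projections.

```python
# Illustrative sketch only; not the diffusers implementation.
def remap_key_sketch(key: str, split_keys=(".processor", ".self_attn")) -> str:
    for split_key in split_keys:
        if split_key in key:
            # Keep everything up to and including the split token; the tail
            # (e.g. ".to_k_lora.down.weight") addresses the LoRA layer inside the processor.
            return key.split(split_key)[0] + split_key
    raise ValueError(f"{key} matches none of {split_keys}")


print(remap_key_sketch("text_model.encoder.layers.0.self_attn.to_k_lora.down.weight"))
# -> "text_model.encoder.layers.0.self_attn"
```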
```diff
@@ -943,14 +943,20 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
             module = self.text_encoder.get_submodule(name)
             # Construct a new function that performs the LoRA merging. We will monkey patch
             # this forward pass.
-            lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
+            attn_processor_name = ".".join(name.split(".")[:-1])
+            lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
             old_forward = module.forward
 
-            def new_forward(x):
-                return old_forward(x) + lora_layer(x)
+            # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
+            # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
+            def make_new_forward(old_forward, lora_layer):
+                def new_forward(x):
+                    return old_forward(x) + lora_layer(x)
+
+                return new_forward
 
             # Monkey-patch.
-            module.forward = new_forward
+            module.forward = make_new_forward(old_forward, lora_layer)
 
     def _get_lora_layer_attribute(self, name: str) -> str:
         if "q_proj" in name:
```

> Review comment (on the new `attn_processor_name` line): To get the correct mapping in the names, as discovered in #3437 (comment).
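Two things are going on in this hunk. First, `".".join(name.split(".")[:-1])` drops the trailing projection segment, so a module name like `text_model.encoder.layers.0.self_attn.q_proj` is looked up under the `...self_attn` key the processors are now stored by. Second, `make_new_forward` is needed because Python closures capture variables by reference: without the factory, every `new_forward` created in the loop would call the last `old_forward`/`lora_layer` the loop saw. A self-contained illustration of that late-binding pitfall (generic Python, not diffusers code):

```python
def broken_patches():
    # Closures defined in a loop all share the same `i`; by the time they run, i == 2.
    funcs = []
    for i in range(3):
        def f():
            return i
        funcs.append(f)
    return [f() for f in funcs]  # [2, 2, 2]


def fixed_patches():
    # A factory function creates a new scope per iteration, freezing the value of `i`.
    funcs = []
    for i in range(3):
        def make_f(i):
            def f():
                return i
            return f
        funcs.append(make_f(i))
    return [f() for f in funcs]  # [0, 1, 2]


assert broken_patches() == [2, 2, 2]
assert fixed_patches() == [0, 1, 2]
```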
File: LoRA loader tests
```diff
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import os
 import tempfile
 import unittest
 
@@ -23,7 +24,7 @@
 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
 from diffusers.models.attention_processor import LoRAAttnProcessor
-from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device
+from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, floats_tensor, torch_device
 
 
 def create_unet_lora_layers(unet: nn.Module):
```
```diff
@@ -43,15 +44,35 @@ def create_unet_lora_layers(unet: nn.Module):
     return lora_attn_procs, unet_lora_layers
 
 
-def create_text_encoder_lora_layers(text_encoder: nn.Module):
+def create_text_encoder_lora_attn_procs(text_encoder: nn.Module):
     text_lora_attn_procs = {}
     for name, module in text_encoder.named_modules():
-        if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
-            text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None)
+        if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+            text_lora_attn_procs[name] = LoRAAttnProcessor(
+                hidden_size=module.out_proj.out_features, cross_attention_dim=None
+            )
+    return text_lora_attn_procs
+
+
+def create_text_encoder_lora_layers(text_encoder: nn.Module):
+    text_lora_attn_procs = create_text_encoder_lora_attn_procs(text_encoder)
     text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
     return text_encoder_lora_layers
+
+
+def set_lora_up_weights(text_lora_attn_procs, randn_weight=False):
+    for _, attn_proc in text_lora_attn_procs.items():
+        # set up.weights
+        for layer_name, layer_module in attn_proc.named_modules():
+            if layer_name.endswith("_lora"):
+                weight = (
+                    torch.randn_like(layer_module.up.weight)
+                    if randn_weight
+                    else torch.zeros_like(layer_module.up.weight)
+                )
+                layer_module.up.weight = torch.nn.Parameter(weight)
 
 
 class LoraLoaderMixinTests(unittest.TestCase):
     def get_dummy_components(self):
         torch.manual_seed(0)
```

> Review comment (on `create_text_encoder_lora_layers`): Seems like we're not using this method. Okay for me to discard.
>
> Reply: I've left this as it still seems to be used inside
```diff
@@ -212,3 +233,61 @@ def test_lora_save_load_legacy(self):
 
         # Outputs shouldn't match.
         self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
+
+    # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
+    def get_dummy_tokens(self):
+        max_seq_length = 77
+
+        inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
+
+        prepared_inputs = {}
+        prepared_inputs["input_ids"] = inputs
+        return prepared_inputs
+
+    def test_text_encoder_lora_monkey_patch(self):
+        pipeline_components, _ = self.get_dummy_components()
+        pipe = StableDiffusionPipeline(**pipeline_components)
+
+        dummy_tokens = self.get_dummy_tokens()
+
+        # inference without lora
+        outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_without_lora.shape == (1, 77, 32)
+
+        # create lora_attn_procs with zeroed out up.weights
+        text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
+        set_lora_up_weights(text_attn_procs, randn_weight=False)
+
+        # monkey patch
+        pipe._modify_text_encoder(text_attn_procs)
+
+        # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
+        del text_attn_procs
+        gc.collect()
+
+        # inference with lora
+        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_with_lora.shape == (1, 77, 32)
+
+        assert torch.allclose(
+            outputs_without_lora, outputs_with_lora
+        ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs"
+
+        # create lora_attn_procs with randn up.weights
+        text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
+        set_lora_up_weights(text_attn_procs, randn_weight=True)
+
+        # monkey patch
+        pipe._modify_text_encoder(text_attn_procs)
+
+        # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
+        del text_attn_procs
+        gc.collect()
+
+        # inference with lora
+        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_with_lora.shape == (1, 77, 32)
+
+        assert not torch.allclose(
+            outputs_without_lora, outputs_with_lora
+        ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs"
```
Review discussion:

> Looks good to me! However, does the following need to be changed, since that LoRA layer mapping is now being changed?
>
> diffusers/src/diffusers/loaders.py, line 74 in 2f997f3

> Cc: @patrickvonplaten, do we need to change the split keys here for the text encoder, given https://github.com/huggingface/diffusers/pull/3505/files#diff-eca8763d65ac395bf286dd84b25abcf5e299a737508155270a648534d781e865R34?

> Nice catch! The current code becomes a problem when the `remap_key()` function is called, so I've fixed it in 160a4d3. However, it seems that the current test code and dreambooth script do not hit the conditions under which `remap_key()` is called, so I haven't been able to test it. As far as I understand, the intention is that this code path runs when `AttnProcsLayers` is loaded directly with `load_state_dict()`, but are there any use cases where this path is actually exercised?

> Thanks! Yeah, I think so too. @patrickvonplaten, could you confirm once?
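For reference, a hedged sketch of the scenario described in the last two comments: building `AttnProcsLayers` from text-encoder LoRA processors and loading its own `state_dict()` back via `load_state_dict()`, which is the path where `remap_key()` and `split_keys` come into play. The config sizes below are assumptions that roughly mirror the dummy test components, and the round-trip is an illustration of how the hooks are meant to compose rather than a test taken from this PR.

```python
from transformers import CLIPTextConfig, CLIPTextModel

from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor

# Tiny text encoder (sizes are assumptions, not taken from the PR).
config = CLIPTextConfig(
    hidden_size=32, intermediate_size=37, num_attention_heads=4, num_hidden_layers=5, vocab_size=1000
)
text_encoder = CLIPTextModel(config)

# One LoRAAttnProcessor per CLIPAttention, keyed by the ".self_attn" module name,
# mirroring create_text_encoder_lora_attn_procs() above.
attn_procs = {
    name: LoRAAttnProcessor(hidden_size=module.out_proj.out_features, cross_attention_dim=None)
    for name, module in text_encoder.named_modules()
    if name.endswith(".self_attn")
}
layers = AttnProcsLayers(attn_procs)

# state_dict() remaps keys to "...self_attn.to_k_lora.down.weight"-style names;
# load_state_dict() must map them back, which only works if ".self_attn" is in split_keys.
layers.load_state_dict(layers.state_dict())
```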