
Fix to generate one LoRAAttnProcessor for each CLIPAttention in TextEncoder LoRA #3505


Closed
Wants to merge 12 commits.
examples/dreambooth/train_dreambooth_lora.py (6 changes: 3 additions, 3 deletions)
@@ -58,7 +58,7 @@
SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


@@ -839,9 +839,9 @@ def main(args):
if args.train_text_encoder:
text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
text_lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=module.out_features, cross_attention_dim=None
hidden_size=module.out_proj.out_features, cross_attention_dim=None

Member:
Looks good to me!

However, does the following also need to be changed, since the LoRA layer key mapping is now different?

self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]


Contributor Author:

Nice catch! The current code becomes a problem when the remap_key() function is called, so I've fixed it in 160a4d3. However, it seems that neither the current test code nor the DreamBooth script hits the conditions under which remap_key() is called, so I haven't been able to test it.

As far as I understand, the intention is that this path is exercised when an AttnProcsLayers instance is loaded directly with load_state_dict(). Are there any use cases that actually call it?

Member:

Thanks! Yeah I think so too. @patrickvonplaten could you confirm once?

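For context, here is a minimal sketch of what the new split keys imply for that remapping, assuming remap_key simply truncates a state-dict key at the first matching split key (stand-in code and assumed key names, not the actual AttnProcsLayers implementation):

# Hypothetical stand-in for the remapping logic; the key names are assumed examples.
split_keys = [".processor", ".self_attn"]

def remap_key(key, split_keys):
    for split_key in split_keys:
        if split_key in key:
            # keep everything up to and including the split key
            return key.split(split_key)[0] + split_key
    raise ValueError(f"{key} matches none of {split_keys}")

# A text-encoder LoRA weight such as
#   text_model.encoder.layers.0.self_attn.to_q_lora.up.weight
# now maps to text_model.encoder.layers.0.self_attn, i.e. one group per
# CLIPAttention module instead of one per q/k/v/out projection.
print(remap_key("text_model.encoder.layers.0.self_attn.to_q_lora.up.weight", split_keys))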
)
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
temp_pipeline = StableDiffusionPipeline.from_pretrained(
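The hunk above is the heart of the change: the old substring match against TEXT_ENCODER_TARGET_MODULES picked up every q_proj/k_proj/v_proj/out_proj Linear and therefore built four LoRAAttnProcessor objects per transformer layer, while the endswith(TEXT_ENCODER_ATTN_MODULE) check matches only the CLIPAttention block itself, whose hidden size is then read from out_proj.out_features. A small self-contained sketch of the two matching rules, assuming the usual transformers CLIP naming (text_model.encoder.layers.N.self_attn.*):

# Sketch only: the module names below are assumed examples based on
# transformers' CLIPTextModel, not taken from this PR.
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
TEXT_ENCODER_ATTN_MODULE = ".self_attn"

names = [
    "text_model.encoder.layers.0.self_attn",
    "text_model.encoder.layers.0.self_attn.q_proj",
    "text_model.encoder.layers.0.self_attn.k_proj",
    "text_model.encoder.layers.0.self_attn.v_proj",
    "text_model.encoder.layers.0.self_attn.out_proj",
]

old_matches = [n for n in names if any(x in n for x in TEXT_ENCODER_TARGET_MODULES)]
new_matches = [n for n in names if n.endswith(TEXT_ENCODER_ATTN_MODULE)]

print(old_matches)  # the four projection Linears per layer -> four processors
print(new_matches)  # the single CLIPAttention module -> one processor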
src/diffusers/loaders.py (18 changes: 12 additions, 6 deletions)
@@ -70,8 +70,8 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]):
self.mapping = dict(enumerate(state_dict.keys()))
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

# .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder
self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
# .processor for unet, .self_attn for text encoder
self.split_keys = [".processor", ".self_attn"]

# we add a hook to state_dict() and load_state_dict() so that the
# naming fits with `unet.attn_processors`
@@ -943,14 +943,20 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
module = self.text_encoder.get_submodule(name)
# Construct a new function that performs the LoRA merging. We will monkey patch
# this forward pass.
lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
attn_processor_name = ".".join(name.split(".")[:-1])

Member:

This is needed to get the correct mapping of the names, as discovered in #3437 (comment).
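Concretely, the submodule names iterated over here still end in the projection, while the processor dict is now keyed by the parent attention module, so the last path component has to be stripped before the lookup. A tiny illustration (the module path is an assumed example following CLIP's naming):

# Illustrative only.
name = "text_model.encoder.layers.0.self_attn.q_proj"
attn_processor_name = ".".join(name.split(".")[:-1])
print(attn_processor_name)  # text_model.encoder.layers.0.self_attn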

lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
old_forward = module.forward

def new_forward(x):
return old_forward(x) + lora_layer(x)
# create a new scope that locks in the old_forward, lora_layer value for each new_forward function
# for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
def make_new_forward(old_forward, lora_layer):
def new_forward(x):
return old_forward(x) + lora_layer(x)

return new_forward

# Monkey-patch.
module.forward = new_forward
module.forward = make_new_forward(old_forward, lora_layer)

def _get_lora_layer_attribute(self, name: str) -> str:
if "q_proj" in name:
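The make_new_forward factory above is needed because Python closures bind names late: a new_forward defined directly in the loop would look up old_forward and lora_layer at call time, so every patched module would end up using the values from the last loop iteration. A minimal, diffusers-independent sketch of the pitfall and the fix:

# Late binding: every closure ends up using the last value bound in the loop.
funcs = []
for i in range(3):
    funcs.append(lambda: i)
print([f() for f in funcs])  # [2, 2, 2]

# A factory function creates a new scope per iteration, locking the value in.
def make(i):
    return lambda: i

funcs = [make(i) for i in range(3)]
print([f() for f in funcs])  # [0, 1, 2]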
src/diffusers/utils/__init__.py (1 change: 1 addition, 0 deletions)
@@ -30,6 +30,7 @@
ONNX_EXTERNAL_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
TEXT_ENCODER_ATTN_MODULE,
TEXT_ENCODER_TARGET_MODULES,
WEIGHTS_NAME,
)
src/diffusers/utils/constants.py (1 change: 1 addition, 0 deletions)
@@ -31,3 +31,4 @@
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
TEXT_ENCODER_ATTN_MODULE = ".self_attn"
tests/models/test_lora_layers.py (87 changes: 83 additions, 4 deletions)
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import tempfile
import unittest
@@ -23,7 +24,7 @@
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device
from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, floats_tensor, torch_device


def create_unet_lora_layers(unet: nn.Module):
@@ -43,15 +44,35 @@ def create_unet_lora_layers(unet: nn.Module):
return lora_attn_procs, unet_lora_layers


def create_text_encoder_lora_layers(text_encoder: nn.Module):
def create_text_encoder_lora_attn_procs(text_encoder: nn.Module):
text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None)
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
text_lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=module.out_proj.out_features, cross_attention_dim=None
)
return text_lora_attn_procs


def create_text_encoder_lora_layers(text_encoder: nn.Module):

Member:

Seems like we're not using this method; I'm okay with discarding it.

Contributor Author:

I've left this as it still seems to be used inside LoraLoaderMixinTests.get_dummy_components().

text_lora_attn_procs = create_text_encoder_lora_attn_procs(text_encoder)
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
return text_encoder_lora_layers


def set_lora_up_weights(text_lora_attn_procs, randn_weight=False):
for _, attn_proc in text_lora_attn_procs.items():
# set up.weights
for layer_name, layer_module in attn_proc.named_modules():
if layer_name.endswith("_lora"):
weight = (
torch.randn_like(layer_module.up.weight)
if randn_weight
else torch.zeros_like(layer_module.up.weight)
)
layer_module.up.weight = torch.nn.Parameter(weight)


class LoraLoaderMixinTests(unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
@@ -212,3 +233,61 @@ def test_lora_save_load_legacy(self):

# Outputs shouldn't match.
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))

# copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
def get_dummy_tokens(self):
max_seq_length = 77

inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))

prepared_inputs = {}
prepared_inputs["input_ids"] = inputs
return prepared_inputs

def test_text_encoder_lora_monkey_patch(self):
pipeline_components, _ = self.get_dummy_components()
pipe = StableDiffusionPipeline(**pipeline_components)

dummy_tokens = self.get_dummy_tokens()

# inference without lora
outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
assert outputs_without_lora.shape == (1, 77, 32)

# create lora_attn_procs with zeroed out up.weights
text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
set_lora_up_weights(text_attn_procs, randn_weight=False)

# monkey patch
pipe._modify_text_encoder(text_attn_procs)

# verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
del text_attn_procs
gc.collect()

# inference with lora
outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
assert outputs_with_lora.shape == (1, 77, 32)

assert torch.allclose(
outputs_without_lora, outputs_with_lora
), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs"

# create lora_attn_procs with randn up.weights
text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
set_lora_up_weights(text_attn_procs, randn_weight=True)

# monkey patch
pipe._modify_text_encoder(text_attn_procs)

# verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
del text_attn_procs
gc.collect()

# inference with lora
outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
assert outputs_with_lora.shape == (1, 77, 32)

assert not torch.allclose(
outputs_without_lora, outputs_with_lora
), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs"