diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index e657406912f2..1058ff352b67 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -34,7 +34,7 @@ from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
-    TEXT_ENCODER_TARGET_MODULES,
+    TEXT_ENCODER_ATTN_MODULE,
     _get_model_file,
     deprecate,
     is_safetensors_available,
@@ -955,6 +955,19 @@ def text_encoder_lora_attn_procs(self):
             return self._text_encoder_lora_attn_procs
         return

+    def _remove_text_encoder_monkey_patch(self):
+        # Loop over the CLIPAttention module of text_encoder
+        for name, attn_module in self.text_encoder.named_modules():
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+                # Loop over the LoRA layers
+                for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
+                    # Retrieve the q/k/v/out projection of CLIPAttention
+                    module = attn_module.get_submodule(text_encoder_attr)
+                    if hasattr(module, "old_forward"):
+                        # restore original `forward` to remove monkey-patch
+                        module.forward = module.old_forward
+                        delattr(module, "old_forward")
+
     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
@@ -963,37 +976,41 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
             attn_processors: Dict[str, `LoRAAttnProcessor`]:
                 A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
         """
-        # Loop over the original attention modules.
-        for name, _ in self.text_encoder.named_modules():
-            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
-                # Retrieve the module and its corresponding LoRA processor.
-                module = self.text_encoder.get_submodule(name)
-                # Construct a new function that performs the LoRA merging. We will monkey patch
-                # this forward pass.
-                attn_processor_name = ".".join(name.split(".")[:-1])
-                lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
-                old_forward = module.forward
-
-                # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
-                # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
-                def make_new_forward(old_forward, lora_layer):
-                    def new_forward(x):
-                        return old_forward(x) + lora_layer(x)
-
-                    return new_forward
-
-                # Monkey-patch.
-                module.forward = make_new_forward(old_forward, lora_layer)
-
-    def _get_lora_layer_attribute(self, name: str) -> str:
-        if "q_proj" in name:
-            return "to_q_lora"
-        elif "v_proj" in name:
-            return "to_v_lora"
-        elif "k_proj" in name:
-            return "to_k_lora"
-        else:
-            return "to_out_lora"
+
+        # First, remove any monkey-patch that might have been applied before
+        self._remove_text_encoder_monkey_patch()
+
+        # Loop over the CLIPAttention module of text_encoder
+        for name, attn_module in self.text_encoder.named_modules():
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+                # Loop over the LoRA layers
+                for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
+                    # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
+                    module = attn_module.get_submodule(text_encoder_attr)
+                    lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
+
+                    # save old_forward to module that can be used to remove monkey-patch
+                    old_forward = module.old_forward = module.forward
+
+                    # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
+                    # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
+                    def make_new_forward(old_forward, lora_layer):
+                        def new_forward(x):
+                            return old_forward(x) + lora_layer(x)
+
+                        return new_forward
+
+                    # Monkey-patch.
+                    module.forward = make_new_forward(old_forward, lora_layer)
+
+    @property
+    def _lora_attn_processor_attr_to_text_encoder_attr(self):
+        return {
+            "to_q_lora": "q_proj",
+            "to_k_lora": "k_proj",
+            "to_v_lora": "v_proj",
+            "to_out_lora": "out_proj",
+        }

     def _load_text_encoder_attn_procs(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 772c36b1177b..36cbe82f79e7 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -31,7 +31,6 @@
     ONNX_WEIGHTS_NAME,
     SAFETENSORS_WEIGHTS_NAME,
     TEXT_ENCODER_ATTN_MODULE,
-    TEXT_ENCODER_TARGET_MODULES,
     WEIGHTS_NAME,
 )
 from .deprecation_utils import deprecate
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 93d5c8cc42cd..3c641a259a81 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -30,5 +30,4 @@
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
-TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
 TEXT_ENCODER_ATTN_MODULE = ".self_attn"
diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py
index d04d87e08b7a..52826fc0c736 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/models/test_lora_layers.py
@@ -163,6 +163,15 @@ def get_dummy_inputs(self):

         return noise, input_ids, pipeline_inputs

+    def create_lora_weight_file(self, tmpdirname):
+        _, lora_components = self.get_dummy_components()
+        LoraLoaderMixin.save_lora_weights(
+            save_directory=tmpdirname,
+            unet_lora_layers=lora_components["unet_lora_layers"],
+            text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
+        )
+        self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+
     def test_lora_save_load(self):
         pipeline_components, lora_components = self.get_dummy_components()
         sd_pipe = StableDiffusionPipeline(**pipeline_components)
@@ -299,14 +308,45 @@ def test_text_encoder_lora_monkey_patch(self):
             outputs_without_lora, outputs_with_lora
         ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs"

-    def create_lora_weight_file(self, tmpdirname):
-        _, lora_components = self.get_dummy_components()
-        LoraLoaderMixin.save_lora_weights(
-            save_directory=tmpdirname,
-            unet_lora_layers=lora_components["unet_lora_layers"],
-            text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
-        )
-        self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+    def test_text_encoder_lora_remove_monkey_patch(self):
+        pipeline_components, _ = self.get_dummy_components()
+        pipe = StableDiffusionPipeline(**pipeline_components)
+
+        dummy_tokens = self.get_dummy_tokens()
+
+        # inference without lora
+        outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_without_lora.shape == (1, 77, 32)
+
+        # create lora_attn_procs with randn up.weights
+        text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
+        set_lora_up_weights(text_attn_procs, randn_weight=True)
+
+        # monkey patch
+        pipe._modify_text_encoder(text_attn_procs)
+
+        # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
+        del text_attn_procs
+        gc.collect()
+
+        # inference with lora
+        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_with_lora.shape == (1, 77, 32)
+
+        assert not torch.allclose(
+            outputs_without_lora, outputs_with_lora
+        ), "lora outputs should be different to without lora outputs"
+
+        # remove monkey patch
+        pipe._remove_text_encoder_monkey_patch()
+
+        # inference with removed lora
+        outputs_without_lora_removed = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_without_lora_removed.shape == (1, 77, 32)
+
+        assert torch.allclose(
+            outputs_without_lora, outputs_without_lora_removed
+        ), "remove lora monkey patch should restore the original outputs"

     def test_lora_unet_attn_processors(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
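For reviewers, here is a minimal, self-contained sketch of the patch/restore pattern the diff implements. It is illustrative only: LoRALinearLayer and proj below are stand-in names rather than the actual diffusers classes, and the real code walks the q/k/v/out projections of every ".self_attn" (CLIPAttention) module in the text encoder via the _lora_attn_processor_attr_to_text_encoder_attr mapping.

# sketch_lora_monkey_patch.py -- illustrative stand-ins, not diffusers code
import torch
import torch.nn as nn


class LoRALinearLayer(nn.Module):
    # minimal low-rank adapter: up(down(x)); up is zero-initialized so it starts as a no-op
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, x):
        return self.up(self.down(x))


proj = nn.Linear(32, 32)  # stands in for a CLIPAttention q/k/v/out projection
lora = LoRALinearLayer(32, 32)

# patch: stash the original forward on the module itself, then wrap it so the
# LoRA output is added on top (the role of make_new_forward in the diff)
proj.old_forward = proj.forward


def make_new_forward(old_forward, lora_layer):
    def new_forward(x):
        return old_forward(x) + lora_layer(x)

    return new_forward


proj.forward = make_new_forward(proj.old_forward, lora)

x = torch.randn(1, 32)
patched_out = proj(x)

# restore: put the saved forward back and drop the marker attribute,
# mirroring _remove_text_encoder_monkey_patch
if hasattr(proj, "old_forward"):
    proj.forward = proj.old_forward
    delattr(proj, "old_forward")

restored_out = proj(x)
# with zero-initialized up weights the LoRA contribution is zero, so both match
assert torch.allclose(patched_out, restored_out)

Because the patched forward is stored as an instance attribute and the original is kept in old_forward, removal only has to check for that attribute and restore it, which is also why _modify_text_encoder can safely call _remove_text_encoder_monkey_patch first to avoid stacking patches on repeated loads.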