diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e50bc31a5c63..7eb389184ed9 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -946,11 +946,15 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) old_forward = module.forward - def new_forward(x): - return old_forward(x) + lora_layer(x) + # create a new scope that locks in the old_forward, lora_layer value for each new_forward function + def make_new_forward(old_forward, lora_layer): + def new_forward(x): + return old_forward(x) + lora_layer(x) + + return new_forward # Monkey-patch. - module.forward = new_forward + module.forward = make_new_forward(old_forward, lora_layer) def _get_lora_layer_attribute(self, name: str) -> str: if "q_proj" in name: diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 6f1e85e15558..528c6e8bc35a 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import gc import os import tempfile import unittest @@ -43,15 +44,33 @@ def create_unet_lora_layers(unet: nn.Module): return lora_attn_procs, unet_lora_layers -def create_text_encoder_lora_layers(text_encoder: nn.Module): +def create_text_encoder_lora_attn_procs(text_encoder: nn.Module): text_lora_attn_procs = {} for name, module in text_encoder.named_modules(): if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) + return text_lora_attn_procs + + +def create_text_encoder_lora_layers(text_encoder: nn.Module): + text_lora_attn_procs = create_text_encoder_lora_attn_procs(text_encoder) text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) return text_encoder_lora_layers +def set_lora_up_weights(text_lora_attn_procs, randn_weight=False): + for _, attn_proc in text_lora_attn_procs.items(): + # set up.weights + for layer_name, layer_module in attn_proc.named_modules(): + if layer_name.endswith("_lora"): + weight = ( + torch.randn_like(layer_module.up.weight) + if randn_weight + else torch.zeros_like(layer_module.up.weight) + ) + layer_module.up.weight = torch.nn.Parameter(weight) + + class LoraLoaderMixinTests(unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) @@ -212,3 +231,61 @@ def test_lora_save_load_legacy(self): # Outputs shouldn't match. self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb + def get_dummy_tokens(self): + max_seq_length = 77 + + inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)) + + prepared_inputs = {} + prepared_inputs["input_ids"] = inputs + return prepared_inputs + + def test_text_encoder_lora_monkey_patch(self): + pipeline_components, _ = self.get_dummy_components() + pipe = StableDiffusionPipeline(**pipeline_components) + + dummy_tokens = self.get_dummy_tokens() + + # inference without lora + outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_without_lora.shape == (1, 77, 32) + + # create lora_attn_procs with zeroed out up.weights + text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) + set_lora_up_weights(text_attn_procs, randn_weight=False) + + # monkey patch + pipe._modify_text_encoder(text_attn_procs) + + # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. + del text_attn_procs + gc.collect() + + # inference with lora + outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_with_lora.shape == (1, 77, 32) + + assert torch.allclose( + outputs_without_lora, outputs_with_lora + ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" + + # create lora_attn_procs with randn up.weights + text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) + set_lora_up_weights(text_attn_procs, randn_weight=True) + + # monkey patch + pipe._modify_text_encoder(text_attn_procs) + + # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. + del text_attn_procs + gc.collect() + + # inference with lora + outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_with_lora.shape == (1, 77, 32) + + assert not torch.allclose( + outputs_without_lora, outputs_with_lora + ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs"