From fb708fba19276c08fdce5849447d725293cbf962 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 16 May 2023 04:13:34 +0900 Subject: [PATCH 1/6] fix monkey-patch for text_encoder --- src/diffusers/loaders.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e50bc31a5c63..f9840bca626e 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -943,14 +943,16 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): module = self.text_encoder.get_submodule(name) # Construct a new function that performs the LoRA merging. We will monkey patch # this forward pass. - lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) - old_forward = module.forward - def new_forward(x): - return old_forward(x) + lora_layer(x) + if name in attn_processors: + module.lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) + module.old_forward = module.forward - # Monkey-patch. - module.forward = new_forward + def new_forward(self, x): + return self.old_forward(x) + self.lora_layer(x) + + # Monkey-patch. + module.forward = new_forward.__get__(module) def _get_lora_layer_attribute(self, name: str) -> str: if "q_proj" in name: From 6e8f3ab897a6c068b5ac997887cd79dbef6618d0 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 20 May 2023 00:29:14 +0900 Subject: [PATCH 2/6] add test_text_encoder_lora_monkey_patch() --- tests/models/test_lora_layers.py | 62 ++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 6f1e85e15558..ffdc7569d2e1 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -212,3 +212,65 @@ def test_lora_save_load_legacy(self): # Outputs shouldn't match. 
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb + def get_dummy_tokens(self): + max_seq_length = 77 + + inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)).to("cuda") + + prepared_inputs = {} + prepared_inputs["input_ids"] = inputs + return prepared_inputs + + def test_text_encoder_lora_monkey_patch(self): + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda") + + dummy_tokens = self.get_dummy_tokens() + + # inference without lora + outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_without_lora.shape == (1, 77, 768) + + text_lora_attn_procs = {} + for name, module in pipe.text_encoder.named_modules(): + if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): + text_lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=module.out_features, cross_attention_dim=None + ).to("cuda") + + # monkey patch + pipe._modify_text_encoder(text_lora_attn_procs) + + # make sure that the lora_up.weights are zeroed out + for name, attn_proc in text_lora_attn_procs.items(): + for n in ["q", "k", "v", "out"]: + n = f"to_{n}_lora" + lora_linear_layer = getattr(attn_proc, n) + lora_up_weight = lora_linear_layer.up.weight + assert torch.allclose( + lora_up_weight, torch.zeros_like(lora_up_weight) + ), "lora_up_weight should be zeroed out" + + # inference with lora + outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_with_lora.shape == (1, 77, 768) + + assert torch.allclose( + outputs_without_lora, outputs_with_lora + ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" + + # make lora_up.weights as random + for name, attn_proc in text_lora_attn_procs.items(): + for n in ["q", "k", "v", "out"]: + n = f"to_{n}_lora" + lora_linear_layer = getattr(attn_proc, n) + lora_linear_layer.up.weight = torch.nn.Parameter(torch.randn_like(lora_linear_layer.up.weight)) + + # inference with lora + outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_with_lora.shape == (1, 77, 768) + + assert not torch.allclose( + outputs_without_lora, outputs_with_lora + ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs" From 851175565342669deeb59aa95e446f31b4b9b256 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 20 May 2023 01:20:00 +0900 Subject: [PATCH 3/6] verify that it's okay to release the attn_procs --- tests/models/test_lora_layers.py | 42 +++++++++++++++++--------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index ffdc7569d2e1..24043544a74d 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import gc import os import tempfile import unittest @@ -22,7 +23,7 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin -from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.attention_processor import LoRAAttnProcessor, LoRALinearLayer from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device @@ -232,25 +233,27 @@ def test_text_encoder_lora_monkey_patch(self): outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] assert outputs_without_lora.shape == (1, 77, 768) + # create lora_attn_procs with zeroed out up.weights text_lora_attn_procs = {} for name, module in pipe.text_encoder.named_modules(): if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): - text_lora_attn_procs[name] = LoRAAttnProcessor( - hidden_size=module.out_features, cross_attention_dim=None - ).to("cuda") + attn_proc = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None).to("cuda") + + # make sure that the up.weights are zeroed out + for layer_name, layer_module in attn_proc.named_modules(): + if layer_name.endswith("_lora"): + assert torch.allclose( + layer_module.up.weight, torch.zeros_like(layer_module.up.weight) + ), "lora_up_weight should be zeroed out" + + text_lora_attn_procs[name] = attn_proc # monkey patch pipe._modify_text_encoder(text_lora_attn_procs) - # make sure that the lora_up.weights are zeroed out - for name, attn_proc in text_lora_attn_procs.items(): - for n in ["q", "k", "v", "out"]: - n = f"to_{n}_lora" - lora_linear_layer = getattr(attn_proc, n) - lora_up_weight = lora_linear_layer.up.weight - assert torch.allclose( - lora_up_weight, torch.zeros_like(lora_up_weight) - ), "lora_up_weight should be zeroed out" + # verify that it's okay to release the text_lora_attn_procs which holds the LoRAAttnProcessor. 
+ del text_lora_attn_procs + gc.collect() # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] @@ -260,12 +263,13 @@ def test_text_encoder_lora_monkey_patch(self): outputs_without_lora, outputs_with_lora ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" - # make lora_up.weights as random - for name, attn_proc in text_lora_attn_procs.items(): - for n in ["q", "k", "v", "out"]: - n = f"to_{n}_lora" - lora_linear_layer = getattr(attn_proc, n) - lora_linear_layer.up.weight = torch.nn.Parameter(torch.randn_like(lora_linear_layer.up.weight)) + # set randn to lora_up.weights + for name, _ in pipe.text_encoder.named_modules(): + if any(name.endswith(x) for x in TEXT_ENCODER_TARGET_MODULES): + module = pipe.text_encoder.get_submodule(name) + assert hasattr(module, "lora_layer"), "lora_layer should be added" + assert isinstance(module.lora_layer, LoRALinearLayer), "lora_layer should be LoRALinearLayer" + module.lora_layer.up.weight = torch.nn.Parameter(torch.randn_like(module.lora_layer.up.weight)) # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] From 81915f48dfff3cd2e2654bc820088572f4e8f5db Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 20 May 2023 03:47:58 +0900 Subject: [PATCH 4/6] fix closure version --- src/diffusers/loaders.py | 15 ++++----- tests/models/test_lora_layers.py | 53 ++++++++++++++++++-------------- 2 files changed, 38 insertions(+), 30 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index f9840bca626e..ad1096f65c21 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -943,16 +943,17 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): module = self.text_encoder.get_submodule(name) # Construct a new function that performs the LoRA merging. We will monkey patch # this forward pass. + lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) + old_forward = module.forward - if name in attn_processors: - module.lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) - module.old_forward = module.forward + def make_new_forward(old_forward, lora_layer): + def new_forward(x): + return old_forward(x) + lora_layer(x) - def new_forward(self, x): - return self.old_forward(x) + self.lora_layer(x) + return new_forward - # Monkey-patch. - module.forward = new_forward.__get__(module) + # Monkey-patch. 
+ module.forward = make_new_forward(old_forward, lora_layer) def _get_lora_layer_attribute(self, name: str) -> str: if "q_proj" in name: diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 24043544a74d..6cf79a0c11cb 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -23,7 +23,7 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin -from diffusers.models.attention_processor import LoRAAttnProcessor, LoRALinearLayer +from diffusers.models.attention_processor import LoRAAttnProcessor from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device @@ -218,14 +218,31 @@ def test_lora_save_load_legacy(self): def get_dummy_tokens(self): max_seq_length = 77 - inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)).to("cuda") + inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)) prepared_inputs = {} prepared_inputs["input_ids"] = inputs return prepared_inputs + def get_text_lora_attn_procs(self, text_encoder: nn.Module, randn_weight=False): + text_lora_attn_procs = {} + for name, module in text_encoder.named_modules(): + if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): + attn_proc = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) + # set up.weights + for layer_name, layer_module in attn_proc.named_modules(): + if layer_name.endswith("_lora"): + weight = ( + torch.randn_like(layer_module.up.weight) + if randn_weight + else torch.zeros_like(layer_module.up.weight) + ) + layer_module.up.weight = torch.nn.Parameter(weight) + text_lora_attn_procs[name] = attn_proc + return text_lora_attn_procs + def test_text_encoder_lora_monkey_patch(self): - pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda") + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") dummy_tokens = self.get_dummy_tokens() @@ -234,19 +251,7 @@ def test_text_encoder_lora_monkey_patch(self): assert outputs_without_lora.shape == (1, 77, 768) # create lora_attn_procs with zeroed out up.weights - text_lora_attn_procs = {} - for name, module in pipe.text_encoder.named_modules(): - if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): - attn_proc = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None).to("cuda") - - # make sure that the up.weights are zeroed out - for layer_name, layer_module in attn_proc.named_modules(): - if layer_name.endswith("_lora"): - assert torch.allclose( - layer_module.up.weight, torch.zeros_like(layer_module.up.weight) - ), "lora_up_weight should be zeroed out" - - text_lora_attn_procs[name] = attn_proc + text_lora_attn_procs = self.get_text_lora_attn_procs(pipe.text_encoder, randn_weight=False) # monkey patch pipe._modify_text_encoder(text_lora_attn_procs) @@ -263,13 +268,15 @@ def test_text_encoder_lora_monkey_patch(self): outputs_without_lora, outputs_with_lora ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" - # set randn to lora_up.weights - for name, _ in pipe.text_encoder.named_modules(): - if any(name.endswith(x) for x in TEXT_ENCODER_TARGET_MODULES): - module = pipe.text_encoder.get_submodule(name) - assert hasattr(module, "lora_layer"), "lora_layer should be added" - assert isinstance(module.lora_layer, LoRALinearLayer), "lora_layer should be LoRALinearLayer" - 
module.lora_layer.up.weight = torch.nn.Parameter(torch.randn_like(module.lora_layer.up.weight)) + # create lora_attn_procs with randn up.weights + text_lora_attn_procs = self.get_text_lora_attn_procs(pipe.text_encoder, randn_weight=True) + + # monkey patch + pipe._modify_text_encoder(text_lora_attn_procs) + + # verify that it's okay to release the text_lora_attn_procs which holds the LoRAAttnProcessor. + del text_lora_attn_procs + gc.collect() # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] From 88db546c01eff271025ab1581f467f41be337c3f Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 20 May 2023 03:53:05 +0900 Subject: [PATCH 5/6] add comment --- src/diffusers/loaders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index ad1096f65c21..7eb389184ed9 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -946,6 +946,7 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) old_forward = module.forward + # create a new scope that locks in the old_forward, lora_layer value for each new_forward function def make_new_forward(old_forward, lora_layer): def new_forward(x): return old_forward(x) + lora_layer(x) From 1da772b9fe8f64702989c1319348dfffd65dc491 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 23 May 2023 00:02:01 +0900 Subject: [PATCH 6/6] Fix to reuse utility functions --- tests/models/test_lora_layers.py | 64 +++++++++++++++++--------------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 6cf79a0c11cb..528c6e8bc35a 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -44,15 +44,33 @@ def create_unet_lora_layers(unet: nn.Module): return lora_attn_procs, unet_lora_layers -def create_text_encoder_lora_layers(text_encoder: nn.Module): +def create_text_encoder_lora_attn_procs(text_encoder: nn.Module): text_lora_attn_procs = {} for name, module in text_encoder.named_modules(): if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) + return text_lora_attn_procs + + +def create_text_encoder_lora_layers(text_encoder: nn.Module): + text_lora_attn_procs = create_text_encoder_lora_attn_procs(text_encoder) text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) return text_encoder_lora_layers +def set_lora_up_weights(text_lora_attn_procs, randn_weight=False): + for _, attn_proc in text_lora_attn_procs.items(): + # set up.weights + for layer_name, layer_module in attn_proc.named_modules(): + if layer_name.endswith("_lora"): + weight = ( + torch.randn_like(layer_module.up.weight) + if randn_weight + else torch.zeros_like(layer_module.up.weight) + ) + layer_module.up.weight = torch.nn.Parameter(weight) + + class LoraLoaderMixinTests(unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) @@ -224,63 +242,49 @@ def get_dummy_tokens(self): prepared_inputs["input_ids"] = inputs return prepared_inputs - def get_text_lora_attn_procs(self, text_encoder: nn.Module, randn_weight=False): - text_lora_attn_procs = {} - for name, module in text_encoder.named_modules(): - if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): - attn_proc = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) - # set up.weights - for 
layer_name, layer_module in attn_proc.named_modules(): - if layer_name.endswith("_lora"): - weight = ( - torch.randn_like(layer_module.up.weight) - if randn_weight - else torch.zeros_like(layer_module.up.weight) - ) - layer_module.up.weight = torch.nn.Parameter(weight) - text_lora_attn_procs[name] = attn_proc - return text_lora_attn_procs - def test_text_encoder_lora_monkey_patch(self): - pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + pipeline_components, _ = self.get_dummy_components() + pipe = StableDiffusionPipeline(**pipeline_components) dummy_tokens = self.get_dummy_tokens() # inference without lora outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_without_lora.shape == (1, 77, 768) + assert outputs_without_lora.shape == (1, 77, 32) # create lora_attn_procs with zeroed out up.weights - text_lora_attn_procs = self.get_text_lora_attn_procs(pipe.text_encoder, randn_weight=False) + text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) + set_lora_up_weights(text_attn_procs, randn_weight=False) # monkey patch - pipe._modify_text_encoder(text_lora_attn_procs) + pipe._modify_text_encoder(text_attn_procs) - # verify that it's okay to release the text_lora_attn_procs which holds the LoRAAttnProcessor. - del text_lora_attn_procs + # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. + del text_attn_procs gc.collect() # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_with_lora.shape == (1, 77, 768) + assert outputs_with_lora.shape == (1, 77, 32) assert torch.allclose( outputs_without_lora, outputs_with_lora ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" # create lora_attn_procs with randn up.weights - text_lora_attn_procs = self.get_text_lora_attn_procs(pipe.text_encoder, randn_weight=True) + text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) + set_lora_up_weights(text_attn_procs, randn_weight=True) # monkey patch - pipe._modify_text_encoder(text_lora_attn_procs) + pipe._modify_text_encoder(text_attn_procs) - # verify that it's okay to release the text_lora_attn_procs which holds the LoRAAttnProcessor. - del text_lora_attn_procs + # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. + del text_attn_procs gc.collect() # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_with_lora.shape == (1, 77, 768) + assert outputs_with_lora.shape == (1, 77, 32) assert not torch.allclose( outputs_without_lora, outputs_with_lora
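
The standalone sketch below (not part of the patch series) distills the closure-based monkey-patch pattern that PATCH 4/6 settles on: a small factory function creates a fresh scope per module, so each patched forward captures its own old_forward and lora_layer instead of every closure sharing the loop variables from the last iteration. Everything in it — the tiny model, TinyLoRALinearLayer, and modify_linears — is made up for illustration; only the patching pattern mirrors the PR.

# Standalone sketch (illustrative only): closure-based monkey-patching of forward,
# as in PATCH 4/6. Names below are hypothetical, not diffusers APIs.
import torch
import torch.nn as nn


class TinyLoRALinearLayer(nn.Module):
    """Minimal stand-in for a LoRA linear layer: down-projection then up-projection."""

    def __init__(self, features: int, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(features, rank, bias=False)
        self.up = nn.Linear(rank, features, bias=False)
        nn.init.zeros_(self.up.weight)  # zero init => patched forward equals the original forward

    def forward(self, x):
        return self.up(self.down(x))


def modify_linears(model: nn.Module, lora_layers: dict):
    """Monkey-patch every submodule named in `lora_layers` so its forward adds the LoRA path."""
    for name, lora_layer in lora_layers.items():
        module = model.get_submodule(name)
        old_forward = module.forward

        # The factory creates a new scope per module. Without it, every new_forward
        # would close over the same loop variables and see only the *last* iteration's
        # old_forward/lora_layer -- the late-binding bug this PR works around.
        def make_new_forward(old_forward, lora_layer):
            def new_forward(x):
                return old_forward(x) + lora_layer(x)

            return new_forward

        # Monkey-patch: the instance attribute shadows the class-level forward.
        module.forward = make_new_forward(old_forward, lora_layer)


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
    x = torch.randn(1, 8)
    baseline = model(x)

    lora_layers = {"0": TinyLoRALinearLayer(8), "2": TinyLoRALinearLayer(8)}
    modify_linears(model, lora_layers)

    # With zeroed up.weight the LoRA path contributes nothing, so outputs match.
    assert torch.allclose(model(x), baseline)

    # Randomize the up weights; the patched outputs should now differ.
    for layer in lora_layers.values():
        layer.up.weight = nn.Parameter(torch.randn_like(layer.up.weight))
    assert not torch.allclose(model(x), baseline)
    print("closure-based monkey patch behaves as expected")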