From d0b07dfcc912de90e8145b8d6af78e8406ce21b8 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Thu, 25 May 2023 03:05:10 +0900 Subject: [PATCH 1/4] fix to use LoRAXFormersAttnProcessor --- src/diffusers/loaders.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e50bc31a5c63..ceb5d2915bea 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -27,7 +27,9 @@ CustomDiffusionXFormersAttnProcessor, LoRAAttnAddedKVProcessor, LoRAAttnProcessor, + LoRAXFormersAttnProcessor, SlicedAttnAddedKVProcessor, + XFormersAttnProcessor, ) from .utils import ( DIFFUSERS_CACHE, @@ -279,7 +281,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict attn_processor_class = LoRAAttnAddedKVProcessor else: cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] - attn_processor_class = LoRAAttnProcessor + if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)): + attn_processor_class = LoRAXFormersAttnProcessor + else: + attn_processor_class = LoRAAttnProcessor attn_processors[key] = attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank From c593dca8572371708b7f480e8dab68791b6932e0 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Fri, 26 May 2023 01:04:15 +0900 Subject: [PATCH 2/4] add test --- tests/models/test_lora_layers.py | 65 +++++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 6f1e85e15558..55f42c30fb24 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -22,7 +22,14 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin -from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.attention_processor import ( + Attention, + AttnProcessor, + AttnProcessor2_0, + LoRAAttnProcessor, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device @@ -212,3 +219,59 @@ def test_lora_save_load_legacy(self): # Outputs shouldn't match. 
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + def create_lora_weight_file(self, tmpdirname): + pipeline_components, lora_components = self.get_dummy_components() + unet_lora_attn_procs = lora_components["unet_lora_attn_procs"] + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe.unet.set_attn_processor(unet_lora_attn_procs) + sd_pipe.unet.save_attn_procs(tmpdirname) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + + def test_lora_unet_attn_processors(self): + with tempfile.TemporaryDirectory() as tmpdirname: + self.create_lora_weight_file(tmpdirname) + + pipeline_components, _ = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + # check if vanilla attention processors are used + for _, module in sd_pipe.unet.named_modules(): + if isinstance(module, Attention): + self.assertIsInstance(module.processor, (AttnProcessor, AttnProcessor2_0)) + + # load LoRA weight file + sd_pipe.load_lora_weights(tmpdirname) + + # check if lora attention processors are used + for _, module in sd_pipe.unet.named_modules(): + if isinstance(module, Attention): + self.assertIsInstance(module.processor, LoRAAttnProcessor) + + @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU") + def test_lora_unet_attn_processors_with_xformers(self): + with tempfile.TemporaryDirectory() as tmpdirname: + self.create_lora_weight_file(tmpdirname) + + pipeline_components, _ = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + # enable XFormers + sd_pipe.enable_xformers_memory_efficient_attention() + + # check if xFormers attention processors are used + for _, module in sd_pipe.unet.named_modules(): + if isinstance(module, Attention): + self.assertIsInstance(module.processor, XFormersAttnProcessor) + + # load LoRA weight file + sd_pipe.load_lora_weights(tmpdirname) + + # check if lora attention processors are used + for _, module in sd_pipe.unet.named_modules(): + if isinstance(module, Attention): + self.assertIsInstance(module.processor, LoRAXFormersAttnProcessor) From bcbf946510baa153ecf241d2a81ffeea1efbb725 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Fri, 26 May 2023 01:52:15 +0900 Subject: [PATCH 3/4] using new LoraLoaderMixin.save_lora_weights --- tests/models/test_lora_layers.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 55f42c30fb24..65df5736e868 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -221,11 +221,12 @@ def test_lora_save_load_legacy(self): self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) def create_lora_weight_file(self, tmpdirname): - pipeline_components, lora_components = self.get_dummy_components() - unet_lora_attn_procs = lora_components["unet_lora_attn_procs"] - sd_pipe = StableDiffusionPipeline(**pipeline_components) - sd_pipe.unet.set_attn_processor(unet_lora_attn_procs) - sd_pipe.unet.save_attn_procs(tmpdirname) + _, lora_components = self.get_dummy_components() + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + 
text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], + ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) def test_lora_unet_attn_processors(self): From 049f5e5fe88718fde9e9e8be0f6d86a3684632ae Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Fri, 26 May 2023 01:57:40 +0900 Subject: [PATCH 4/4] add test_lora_save_load_with_xformers --- tests/models/test_lora_layers.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 65df5736e868..64e30ba4057d 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -276,3 +276,33 @@ def test_lora_unet_attn_processors_with_xformers(self): for _, module in sd_pipe.unet.named_modules(): if isinstance(module, Attention): self.assertIsInstance(module.processor, LoRAXFormersAttnProcessor) + + @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU") + def test_lora_save_load_with_xformers(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + noise, input_ids, pipeline_inputs = self.get_dummy_inputs() + + # enable XFormers + sd_pipe.enable_xformers_memory_efficient_attention() + + original_images = sd_pipe(**pipeline_inputs).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
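---

Usage sketch (not part of the patch series): the fix means that enabling xFormers *before* loading LoRA weights now keeps the memory-efficient attention path instead of silently reverting to the vanilla LoRA processor. Below is a minimal example of the fixed behavior, assuming a CUDA machine with xformers installed; the model id and the "./lora_weights" directory (expected to contain pytorch_lora_weights.bin) are placeholders, not part of this patch.

    import torch
    from diffusers import StableDiffusionPipeline
    from diffusers.models.attention_processor import Attention, LoRAXFormersAttnProcessor

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # placeholder model id
    ).to("cuda")

    # Enable xFormers first; with this patch, load_attn_procs inspects the
    # active processor and selects LoRAXFormersAttnProcessor instead of
    # falling back to LoRAAttnProcessor.
    pipe.enable_xformers_memory_efficient_attention()

    pipe.load_lora_weights("./lora_weights")  # hypothetical local weight directory

    # Mirror the checks in the new tests: every UNet attention module should
    # now use the xFormers-aware LoRA processor.
    for _, module in pipe.unet.named_modules():
        if isinstance(module, Attention):
            assert isinstance(module.processor, LoRAXFormersAttnProcessor)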