From 6cdb47cf0dc9aae8ceffd540c0b9337d4ef0023e Mon Sep 17 00:00:00 2001 From: Pie31415 Date: Sat, 25 Mar 2023 02:18:43 -0400 Subject: [PATCH 1/4] inital commit for lora test cases --- tests/models/test_models_unet_2d_condition.py | 23 +- tests/models/test_models_unet_3d_condition.py | 227 ++++++++++++++++-- 2 files changed, 213 insertions(+), 37 deletions(-) diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 08e960dcd1da..60ef546df709 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -465,28 +465,7 @@ def test_lora_save_load_safetensors(self): with torch.no_grad(): old_sample = model(**inputs_dict).sample - lora_attn_procs = {} - for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim - if name.startswith("mid_block"): - hidden_size = model.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(model.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = model.config.block_out_channels[block_id] - - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) - lora_attn_procs[name] = lora_attn_procs[name].to(model.device) - - # add 1 to weights to mock trained weights - with torch.no_grad(): - lora_attn_procs[name].to_q_lora.up.weight += 1 - lora_attn_procs[name].to_k_lora.up.weight += 1 - lora_attn_procs[name].to_v_lora.up.weight += 1 - lora_attn_procs[name].to_out_lora.up.weight += 1 - + lora_attn_procs = create_lora_layers(model) model.set_attn_processor(lora_attn_procs) with torch.no_grad(): diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index ea71ae4af26c..b49903a7816a 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import torch +import tempfile import unittest - import numpy as np -import torch from diffusers.models import ModelMixin, UNet3DConditionModel -from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor from diffusers.utils import ( floats_tensor, logging, @@ -193,23 +194,219 @@ def test_model_attention_slicing(self): output = model(**inputs_dict) assert output is not None - # (`attn_processors`) needs to be implemented in this model for this test. 
- # def test_lora_processors(self): + def test_lora_processors(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + sample1 = model(**inputs_dict).sample + + lora_attn_procs = {} + for name in model.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = model.config.block_out_channels[-1] + elif name.startswith("up_block"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + + with torch.no_grad(): + lora_attn_procs[name].to_q_lora.up.weight += 1 + lora_attn_procs[name].to_k_lora.up.weight += 1 + lora_attn_procs[name].to_v_lora.up.weight += 1 + lora_attn_procs[name].to_out_lora.up.weight += 1 + + # make sure we can set a list of attention processors + model.set_attn_processor(lora_attn_procs) + model.to(torch_device) + + # test that attn processors can be set to itself + model.set_attn_processor(model.attn_processors) + + with torch.no_grad(): + sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample + sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + assert (sample1 - sample2).abs().max() < 1e-4 + assert (sample3 - sample4).abs().max() < 1e-4 + + # sample 2 and sample 3 should be different + assert (sample2 - sample3).abs().max() > 1e-4 - # (`attn_processors`) needs to be implemented in this model for this test. - # def test_lora_save_load(self): + def test_lora_save_load(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - # (`attn_processors`) needs to be implemented for this test in the model. - # def test_lora_save_load_safetensors(self): + init_dict["attention_head_dim"] = 8 - # (`attn_processors`) needs to be implemented for this test in the model. - # def test_lora_save_safetensors_load_torch(self): + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) - # (`attn_processors`) needs to be implemented for this test. 
- # def test_lora_save_torch_force_load_safetensors_error(self): + with torch.no_grad(): + old_sample = model(**inputs_dict).sample + + lora_attn_procs = create_lora_layers(model) + model.set_attn_processor(lora_attn_procs) + + with torch.no_grad(): + sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + new_model.load_attn_procs(tmpdirname) + + with torch.no_grad(): + new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + assert (sample - new_sample).abs().max() < 1e-4 + + # LoRA and no LoRA should NOT be the same + assert (sample - old_sample).abs().max() > 1e-4 + + def test_lora_save_load_safetensors(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + old_sample = model(**inputs_dict).sample + + lora_attn_procs = create_lora_layers(model) + model.set_attn_processor(lora_attn_procs) + + with torch.no_grad(): + sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname, safe_serialization=True) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + new_model.load_attn_procs(tmpdirname) + + with torch.no_grad(): + new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + assert (sample - new_sample).abs().max() < 1e-4 + + # LoRA and no LoRA should NOT be the same + assert (sample - old_sample).abs().max() > 1e-4 + + def test_lora_save_safetensors_load_torch(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + lora_attn_procs = {} + for name in model.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = model.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + lora_attn_procs[name] = lora_attn_procs[name].to(model.device) + + model.set_attn_processor(lora_attn_procs) + # Saving as torch, properly reloads with directly filename + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin") + + def 
test_lora_save_torch_force_load_safetensors_error(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + lora_attn_procs = {} + for name in model.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = model.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + lora_attn_procs[name] = lora_attn_procs[name].to(model.device) + + model.set_attn_processor(lora_attn_procs) + # Saving as torch, properly reloads with directly filename + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + with self.assertRaises(IOError) as e: + new_model.load_attn_procs(tmpdirname, use_safetensors=True) + self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception)) + + def test_lora_on_off(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + old_sample = model(**inputs_dict).sample + + lora_attn_procs = create_lora_layers(model) + model.set_attn_processor(lora_attn_procs) + + with torch.no_grad(): + sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample + + model.set_attn_processor(AttnProcessor()) + + with torch.no_grad(): + new_sample = model(**inputs_dict).sample - # (`attn_processors`) needs to be added for this test. 
- # def test_lora_on_off(self): + assert (sample - new_sample).abs().max() < 1e-4 + assert (sample - old_sample).abs().max() < 1e-4 @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), From 5a208ce0dfed16bd96f6203f2892b3656297c028 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 4 Apr 2023 14:55:10 +0200 Subject: [PATCH 2/4] help a bit with lora for 3d --- src/diffusers/models/unet_3d_blocks.py | 12 +++++++++--- src/diffusers/models/unet_3d_condition.py | 4 +++- tests/models/test_models_unet_3d_condition.py | 12 ++++++++++-- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 9f8ee2a22aab..2c86171610bf 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -251,7 +251,9 @@ def forward( encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, ).sample - hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) @@ -376,7 +378,9 @@ def forward( encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, ).sample - hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample output_states += (hidden_states,) @@ -587,7 +591,9 @@ def forward( encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, ).sample - hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 8006d0e1c127..ece1a34e4fb9 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -458,7 +458,9 @@ def forward( sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) sample = self.conv_in(sample) - sample = self.transformer_in(sample, num_frames=num_frames).sample + sample = self.transformer_in( + sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample # 3. down down_block_res_samples = (sample,) diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index b49903a7816a..20b3c4808081 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -14,10 +14,11 @@ # limitations under the License. 
import os -import torch import tempfile import unittest + import numpy as np +import torch from diffusers.models import ModelMixin, UNet3DConditionModel from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor @@ -207,7 +208,11 @@ def test_lora_processors(self): lora_attn_procs = {} for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + has_cross_attention = name.endswith("attn2.processor") and not ( + name.startswith("transformer_in") or "temp_attentions" in name.split(".") + ) + cross_attention_dim = model.config.cross_attention_dim if has_cross_attention else None + if name.startswith("mid_block"): hidden_size = model.config.block_out_channels[-1] elif name.startswith("up_block"): @@ -216,6 +221,9 @@ def test_lora_processors(self): elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = model.config.block_out_channels[block_id] + elif name.startswith("transformer_in"): + # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148 + hidden_size = 8 * model.config.attention_head_dim lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) From c3373a115ecd53f87cf1d52c622cdc9c3b209935 Mon Sep 17 00:00:00 2001 From: Pie31415 Date: Tue, 4 Apr 2023 16:17:14 -0400 Subject: [PATCH 3/4] fixed lora tests --- src/diffusers/models/unet_3d_condition.py | 3 +- tests/models/test_models_unet_3d_condition.py | 30 +++++++++++++++---- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index cf6822013f7a..6fb5dfa30ebf 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -20,6 +20,7 @@ import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import UNet2DConditionLoadersMixin from ..utils import BaseOutput, logging from .attention_processor import AttentionProcessor, AttnProcessor from .embeddings import TimestepEmbedding, Timesteps @@ -50,7 +51,7 @@ class UNet3DConditionOutput(BaseOutput): sample: torch.FloatTensor -class UNet3DConditionModel(ModelMixin, ConfigMixin): +class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep and returns sample shaped output. 
diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index f81709f68f44..9c522a03c4a9 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -40,7 +40,10 @@ def create_lora_layers(model): lora_attn_procs = {} for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + has_cross_attention = name.endswith("attn2.processor") and not ( + name.startswith("transformer_in") or "temp_attentions" in name.split(".") + ) + cross_attention_dim = model.config.cross_attention_dim if has_cross_attention else None if name.startswith("mid_block"): hidden_size = model.config.block_out_channels[-1] elif name.startswith("up_blocks"): @@ -49,6 +52,9 @@ def create_lora_layers(model): elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = model.config.block_out_channels[block_id] + elif name.startswith("transformer_in"): + # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148 + hidden_size = 8 * model.config.attention_head_dim lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) lora_attn_procs[name] = lora_attn_procs[name].to(model.device) @@ -328,15 +334,22 @@ def test_lora_save_safetensors_load_torch(self): lora_attn_procs = {} for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + has_cross_attention = name.endswith("attn2.processor") and not ( + name.startswith("transformer_in") or "temp_attentions" in name.split(".") + ) + cross_attention_dim = model.config.cross_attention_dim if has_cross_attention else None + if name.startswith("mid_block"): hidden_size = model.config.block_out_channels[-1] - elif name.startswith("up_blocks"): + elif name.startswith("up_block"): block_id = int(name[len("up_blocks.")]) hidden_size = list(reversed(model.config.block_out_channels))[block_id] elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = model.config.block_out_channels[block_id] + elif name.startswith("transformer_in"): + # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148 + hidden_size = 8 * model.config.attention_head_dim lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) lora_attn_procs[name] = lora_attn_procs[name].to(model.device) @@ -362,15 +375,22 @@ def test_lora_save_torch_force_load_safetensors_error(self): lora_attn_procs = {} for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + has_cross_attention = name.endswith("attn2.processor") and not ( + name.startswith("transformer_in") or "temp_attentions" in name.split(".") + ) + cross_attention_dim = model.config.cross_attention_dim if has_cross_attention else None + if name.startswith("mid_block"): hidden_size = model.config.block_out_channels[-1] - elif name.startswith("up_blocks"): + elif name.startswith("up_block"): block_id = int(name[len("up_blocks.")]) hidden_size = list(reversed(model.config.block_out_channels))[block_id] elif 
name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = model.config.block_out_channels[block_id] + elif name.startswith("transformer_in"): + # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148 + hidden_size = 8 * model.config.attention_head_dim lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) lora_attn_procs[name] = lora_attn_procs[name].to(model.device) From 85a484614e74853e70d9480bcee540dd5dc6aa68 Mon Sep 17 00:00:00 2001 From: Pie31415 Date: Mon, 10 Apr 2023 17:41:12 -0400 Subject: [PATCH 4/4] replaced redundant code --- tests/models/test_models_unet_2d_condition.py | 68 +++----------- tests/models/test_models_unet_3d_condition.py | 88 +++---------------- 2 files changed, 22 insertions(+), 134 deletions(-) diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 86589ce02560..17e08e0a426e 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -41,7 +41,7 @@ torch.backends.cuda.matmul.allow_tf32 = False -def create_lora_layers(model): +def create_lora_layers(model, mock_weights: bool = True): lora_attn_procs = {} for name in model.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim @@ -57,12 +57,13 @@ def create_lora_layers(model): lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) lora_attn_procs[name] = lora_attn_procs[name].to(model.device) - # add 1 to weights to mock trained weights - with torch.no_grad(): - lora_attn_procs[name].to_q_lora.up.weight += 1 - lora_attn_procs[name].to_k_lora.up.weight += 1 - lora_attn_procs[name].to_v_lora.up.weight += 1 - lora_attn_procs[name].to_out_lora.up.weight += 1 + if mock_weights: + # add 1 to weights to mock trained weights + with torch.no_grad(): + lora_attn_procs[name].to_q_lora.up.weight += 1 + lora_attn_procs[name].to_k_lora.up.weight += 1 + lora_attn_procs[name].to_v_lora.up.weight += 1 + lora_attn_procs[name].to_out_lora.up.weight += 1 return lora_attn_procs @@ -378,26 +379,7 @@ def test_lora_processors(self): with torch.no_grad(): sample1 = model(**inputs_dict).sample - lora_attn_procs = {} - for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim - if name.startswith("mid_block"): - hidden_size = model.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(model.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = model.config.block_out_channels[block_id] - - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) - - # add 1 to weights to mock trained weights - with torch.no_grad(): - lora_attn_procs[name].to_q_lora.up.weight += 1 - lora_attn_procs[name].to_k_lora.up.weight += 1 - lora_attn_procs[name].to_v_lora.up.weight += 1 - lora_attn_procs[name].to_out_lora.up.weight += 1 + lora_attn_procs = create_lora_layers(model) # make sure we can set a list of attention processors model.set_attn_processor(lora_attn_procs) @@ -497,21 +479,7 @@ def 
test_lora_save_safetensors_load_torch(self): model = self.model_class(**init_dict) model.to(torch_device) - lora_attn_procs = {} - for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim - if name.startswith("mid_block"): - hidden_size = model.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(model.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = model.config.block_out_channels[block_id] - - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) - lora_attn_procs[name] = lora_attn_procs[name].to(model.device) - + lora_attn_procs = create_lora_layers(model, mock_weights=False) model.set_attn_processor(lora_attn_procs) # Saving as torch, properly reloads with directly filename with tempfile.TemporaryDirectory() as tmpdirname: @@ -532,21 +500,7 @@ def test_lora_save_torch_force_load_safetensors_error(self): model = self.model_class(**init_dict) model.to(torch_device) - lora_attn_procs = {} - for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim - if name.startswith("mid_block"): - hidden_size = model.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(model.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = model.config.block_out_channels[block_id] - - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) - lora_attn_procs[name] = lora_attn_procs[name].to(model.device) - + lora_attn_procs = create_lora_layers(model, mock_weights=False) model.set_attn_processor(lora_attn_procs) # Saving as torch, properly reloads with directly filename with tempfile.TemporaryDirectory() as tmpdirname: diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index 9c522a03c4a9..c552b503af05 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -37,7 +37,7 @@ torch.backends.cuda.matmul.allow_tf32 = False -def create_lora_layers(model): +def create_lora_layers(model, mock_weights: bool = True): lora_attn_procs = {} for name in model.attn_processors.keys(): has_cross_attention = name.endswith("attn2.processor") and not ( @@ -59,12 +59,13 @@ def create_lora_layers(model): lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) lora_attn_procs[name] = lora_attn_procs[name].to(model.device) - # add 1 to weights to mock trained weights - with torch.no_grad(): - lora_attn_procs[name].to_q_lora.up.weight += 1 - lora_attn_procs[name].to_k_lora.up.weight += 1 - lora_attn_procs[name].to_v_lora.up.weight += 1 - lora_attn_procs[name].to_out_lora.up.weight += 1 + if mock_weights: + # add 1 to weights to mock trained weights + with torch.no_grad(): + lora_attn_procs[name].to_q_lora.up.weight += 1 + lora_attn_procs[name].to_k_lora.up.weight += 1 + lora_attn_procs[name].to_v_lora.up.weight += 1 + lora_attn_procs[name].to_out_lora.up.weight += 1 return lora_attn_procs @@ -209,32 +210,7 @@ def test_lora_processors(self): with torch.no_grad(): 
sample1 = model(**inputs_dict).sample - lora_attn_procs = {} - for name in model.attn_processors.keys(): - has_cross_attention = name.endswith("attn2.processor") and not ( - name.startswith("transformer_in") or "temp_attentions" in name.split(".") - ) - cross_attention_dim = model.config.cross_attention_dim if has_cross_attention else None - - if name.startswith("mid_block"): - hidden_size = model.config.block_out_channels[-1] - elif name.startswith("up_block"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(model.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = model.config.block_out_channels[block_id] - elif name.startswith("transformer_in"): - # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148 - hidden_size = 8 * model.config.attention_head_dim - - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) - - with torch.no_grad(): - lora_attn_procs[name].to_q_lora.up.weight += 1 - lora_attn_procs[name].to_k_lora.up.weight += 1 - lora_attn_procs[name].to_v_lora.up.weight += 1 - lora_attn_procs[name].to_out_lora.up.weight += 1 + lora_attn_procs = create_lora_layers(model) # make sure we can set a list of attention processors model.set_attn_processor(lora_attn_procs) @@ -332,28 +308,7 @@ def test_lora_save_safetensors_load_torch(self): model = self.model_class(**init_dict) model.to(torch_device) - lora_attn_procs = {} - for name in model.attn_processors.keys(): - has_cross_attention = name.endswith("attn2.processor") and not ( - name.startswith("transformer_in") or "temp_attentions" in name.split(".") - ) - cross_attention_dim = model.config.cross_attention_dim if has_cross_attention else None - - if name.startswith("mid_block"): - hidden_size = model.config.block_out_channels[-1] - elif name.startswith("up_block"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(model.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = model.config.block_out_channels[block_id] - elif name.startswith("transformer_in"): - # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148 - hidden_size = 8 * model.config.attention_head_dim - - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) - lora_attn_procs[name] = lora_attn_procs[name].to(model.device) - + lora_attn_procs = create_lora_layers(model, mock_weights=False) model.set_attn_processor(lora_attn_procs) # Saving as torch, properly reloads with directly filename with tempfile.TemporaryDirectory() as tmpdirname: @@ -373,28 +328,7 @@ def test_lora_save_torch_force_load_safetensors_error(self): model = self.model_class(**init_dict) model.to(torch_device) - lora_attn_procs = {} - for name in model.attn_processors.keys(): - has_cross_attention = name.endswith("attn2.processor") and not ( - name.startswith("transformer_in") or "temp_attentions" in name.split(".") - ) - cross_attention_dim = model.config.cross_attention_dim if has_cross_attention else None - - if name.startswith("mid_block"): - hidden_size = model.config.block_out_channels[-1] - elif name.startswith("up_block"): - block_id = 
int(name[len("up_blocks.")]) - hidden_size = list(reversed(model.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = model.config.block_out_channels[block_id] - elif name.startswith("transformer_in"): - # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148 - hidden_size = 8 * model.config.attention_head_dim - - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) - lora_attn_procs[name] = lora_attn_procs[name].to(model.device) - + lora_attn_procs = create_lora_layers(model, mock_weights=False) model.set_attn_processor(lora_attn_procs) # Saving as torch, properly reloads with directly filename with tempfile.TemporaryDirectory() as tmpdirname:
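
For reference, the per-processor dimension resolution that the shared create_lora_layers helper applies to the 3D UNet can be exercised on its own, outside a model instance. The sketch below is illustrative only: resolve_lora_dims and the example processor names are hypothetical, and the config values block_out_channels=(32, 64) and cross_attention_dim=32 are placeholder assumptions; only attention_head_dim=8 corresponds to what the tests above set explicitly, and the transformer_in branch just mirrors the `8 * attention_head_dim` note from the patch.

from typing import Optional, Tuple


def resolve_lora_dims(
    name: str,
    block_out_channels: Tuple[int, ...] = (32, 64),
    cross_attention_dim: int = 32,
    attention_head_dim: int = 8,
) -> Tuple[int, Optional[int]]:
    """Return (hidden_size, cross_attention_dim or None) for one attention processor name."""
    # Only attn2 processors outside transformer_in / temp_attentions perform cross-attention.
    has_cross_attention = name.endswith("attn2.processor") and not (
        name.startswith("transformer_in") or "temp_attentions" in name.split(".")
    )
    xattn_dim = cross_attention_dim if has_cross_attention else None

    if name.startswith("mid_block"):
        hidden_size = block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = block_out_channels[block_id]
    elif name.startswith("transformer_in"):
        # Mirrors the `8 * ...` note referenced in the patch (unet_3d_condition.py#L148).
        hidden_size = 8 * attention_head_dim
    else:
        raise ValueError(f"Unexpected attention processor name: {name}")
    return hidden_size, xattn_dim


if __name__ == "__main__":
    # Example (hypothetical) processor names covering each branch.
    for example_name in (
        "transformer_in.transformer_blocks.0.attn1.processor",
        "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor",
        "down_blocks.0.temp_attentions.0.transformer_blocks.0.attn2.processor",
        "mid_block.attentions.0.transformer_blocks.0.attn1.processor",
        "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor",
    ):
        print(example_name, "->", resolve_lora_dims(example_name))

Consolidating this logic into a single create_lora_layers(model, mock_weights=True) helper, as the final commit does, removes the near-identical copies that the earlier commits carried in each test, while mock_weights=False lets the pure save/load-format tests skip the "+1" weight perturbation used to mock trained weights.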