diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py
index 9f8ee2a22aab..2c86171610bf 100644
--- a/src/diffusers/models/unet_3d_blocks.py
+++ b/src/diffusers/models/unet_3d_blocks.py
@@ -251,7 +251,9 @@ def forward(
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
             ).sample
-            hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
+            hidden_states = temp_attn(
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+            ).sample
             hidden_states = resnet(hidden_states, temb)
             hidden_states = temp_conv(hidden_states, num_frames=num_frames)

@@ -376,7 +378,9 @@ def forward(
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
             ).sample
-            hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
+            hidden_states = temp_attn(
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+            ).sample

             output_states += (hidden_states,)

@@ -587,7 +591,9 @@ def forward(
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
             ).sample
-            hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
+            hidden_states = temp_attn(
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+            ).sample

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py
index ec8865f31031..6fb5dfa30ebf 100644
--- a/src/diffusers/models/unet_3d_condition.py
+++ b/src/diffusers/models/unet_3d_condition.py
@@ -20,6 +20,7 @@
 import torch.utils.checkpoint

 from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
@@ -50,7 +51,7 @@ class UNet3DConditionOutput(BaseOutput):
     sample: torch.FloatTensor


-class UNet3DConditionModel(ModelMixin, ConfigMixin):
+class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     r"""
     UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
     and returns sample shaped output.
@@ -465,7 +466,9 @@ def forward(
         sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
         sample = self.conv_in(sample)

-        sample = self.transformer_in(sample, num_frames=num_frames).sample
+        sample = self.transformer_in(
+            sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+        ).sample

         # 3. down
         down_block_res_samples = (sample,)
diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py
index c0cb9d3d8ebd..17e08e0a426e 100644
--- a/tests/models/test_models_unet_2d_condition.py
+++ b/tests/models/test_models_unet_2d_condition.py
@@ -41,7 +41,7 @@
 torch.backends.cuda.matmul.allow_tf32 = False


-def create_lora_layers(model):
+def create_lora_layers(model, mock_weights: bool = True):
     lora_attn_procs = {}
     for name in model.attn_processors.keys():
         cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
@@ -57,12 +57,13 @@ def create_lora_layers(model):
         lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         lora_attn_procs[name] = lora_attn_procs[name].to(model.device)

-        # add 1 to weights to mock trained weights
-        with torch.no_grad():
-            lora_attn_procs[name].to_q_lora.up.weight += 1
-            lora_attn_procs[name].to_k_lora.up.weight += 1
-            lora_attn_procs[name].to_v_lora.up.weight += 1
-            lora_attn_procs[name].to_out_lora.up.weight += 1
+        if mock_weights:
+            # add 1 to weights to mock trained weights
+            with torch.no_grad():
+                lora_attn_procs[name].to_q_lora.up.weight += 1
+                lora_attn_procs[name].to_k_lora.up.weight += 1
+                lora_attn_procs[name].to_v_lora.up.weight += 1
+                lora_attn_procs[name].to_out_lora.up.weight += 1

     return lora_attn_procs

@@ -378,26 +379,7 @@ def test_lora_processors(self):
         with torch.no_grad():
             sample1 = model(**inputs_dict).sample

-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-
-            # add 1 to weights to mock trained weights
-            with torch.no_grad():
-                lora_attn_procs[name].to_q_lora.up.weight += 1
-                lora_attn_procs[name].to_k_lora.up.weight += 1
-                lora_attn_procs[name].to_v_lora.up.weight += 1
-                lora_attn_procs[name].to_out_lora.up.weight += 1
+        lora_attn_procs = create_lora_layers(model)

         # make sure we can set a list of attention processors
         model.set_attn_processor(lora_attn_procs)
@@ -465,28 +447,7 @@ def test_lora_save_load_safetensors(self):
         with torch.no_grad():
             old_sample = model(**inputs_dict).sample

-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
-            # add 1 to weights to mock trained weights
-            with torch.no_grad():
-                lora_attn_procs[name].to_q_lora.up.weight += 1
-                lora_attn_procs[name].to_k_lora.up.weight += 1
-                lora_attn_procs[name].to_v_lora.up.weight += 1
-                lora_attn_procs[name].to_out_lora.up.weight += 1
-
+        lora_attn_procs = create_lora_layers(model)
         model.set_attn_processor(lora_attn_procs)

         with torch.no_grad():
@@ -518,21 +479,7 @@ def test_lora_save_safetensors_load_torch(self):
         model = self.model_class(**init_dict)
         model.to(torch_device)

-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)
         model.set_attn_processor(lora_attn_procs)
         # Saving as torch, properly reloads with directly filename
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -553,21 +500,7 @@ def test_lora_save_torch_force_load_safetensors_error(self):
         model = self.model_class(**init_dict)
         model.to(torch_device)

-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)
         model.set_attn_processor(lora_attn_procs)
         # Saving as torch, properly reloads with directly filename
         with tempfile.TemporaryDirectory() as tmpdirname:
diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py
index 5a0d74a3ea5a..c552b503af05 100644
--- a/tests/models/test_models_unet_3d_condition.py
+++ b/tests/models/test_models_unet_3d_condition.py
@@ -13,13 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+import tempfile
 import unittest

 import numpy as np
 import torch

 from diffusers.models import ModelMixin, UNet3DConditionModel
-from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor
 from diffusers.utils import (
     floats_tensor,
     logging,
@@ -35,10 +37,13 @@
 torch.backends.cuda.matmul.allow_tf32 = False


-def create_lora_layers(model):
+def create_lora_layers(model, mock_weights: bool = True):
     lora_attn_procs = {}
     for name in model.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+        has_cross_attention = name.endswith("attn2.processor") and not (
+            name.startswith("transformer_in") or "temp_attentions" in name.split(".")
+        )
+        cross_attention_dim = model.config.cross_attention_dim if has_cross_attention else None
         if name.startswith("mid_block"):
             hidden_size = model.config.block_out_channels[-1]
         elif name.startswith("up_blocks"):
@@ -47,16 +52,20 @@ def create_lora_layers(model):
         elif name.startswith("down_blocks"):
             block_id = int(name[len("down_blocks.")])
             hidden_size = model.config.block_out_channels[block_id]
+        elif name.startswith("transformer_in"):
+            # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148
+            hidden_size = 8 * model.config.attention_head_dim

         lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         lora_attn_procs[name] = lora_attn_procs[name].to(model.device)

-        # add 1 to weights to mock trained weights
-        with torch.no_grad():
-            lora_attn_procs[name].to_q_lora.up.weight += 1
-            lora_attn_procs[name].to_k_lora.up.weight += 1
-            lora_attn_procs[name].to_v_lora.up.weight += 1
-            lora_attn_procs[name].to_out_lora.up.weight += 1
+        if mock_weights:
+            # add 1 to weights to mock trained weights
+            with torch.no_grad():
+                lora_attn_procs[name].to_q_lora.up.weight += 1
+                lora_attn_procs[name].to_k_lora.up.weight += 1
+                lora_attn_procs[name].to_v_lora.up.weight += 1
+                lora_attn_procs[name].to_out_lora.up.weight += 1

     return lora_attn_procs

@@ -190,23 +199,173 @@ def test_model_attention_slicing(self):
             output = model(**inputs_dict)
         assert output is not None

-    # (`attn_processors`) needs to be implemented in this model for this test.
-    # def test_lora_processors(self):
+    def test_lora_processors(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            sample1 = model(**inputs_dict).sample
+
+        lora_attn_procs = create_lora_layers(model)
+
+        # make sure we can set a list of attention processors
+        model.set_attn_processor(lora_attn_procs)
+        model.to(torch_device)
+
+        # test that attn processors can be set to itself
+        model.set_attn_processor(model.attn_processors)
+
+        with torch.no_grad():
+            sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
+            sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+            sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        assert (sample1 - sample2).abs().max() < 1e-4
+        assert (sample3 - sample4).abs().max() < 1e-4
+
+        # sample 2 and sample 3 should be different
+        assert (sample2 - sample3).abs().max() > 1e-4
+
+    def test_lora_save_load(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            new_model.load_attn_procs(tmpdirname)
+
+        with torch.no_grad():
+            new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        assert (sample - new_sample).abs().max() < 1e-4
+
+        # LoRA and no LoRA should NOT be the same
+        assert (sample - old_sample).abs().max() > 1e-4
+
+    def test_lora_save_load_safetensors(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname, safe_serialization=True)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            new_model.load_attn_procs(tmpdirname)
+
+        with torch.no_grad():
+            new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        assert (sample - new_sample).abs().max() < 1e-4
+
+        # LoRA and no LoRA should NOT be the same
+        assert (sample - old_sample).abs().max() > 1e-4
+
+    def test_lora_save_safetensors_load_torch(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)
+        model.set_attn_processor(lora_attn_procs)
+        # Saving as torch, properly reloads with directly filename
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin")
+
+    def test_lora_save_torch_force_load_safetensors_error(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8

-    # (`attn_processors`) needs to be implemented in this model for this test.
-    # def test_lora_save_load(self):
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)

-    # (`attn_processors`) needs to be implemented for this test in the model.
-    # def test_lora_save_load_safetensors(self):
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)
+        model.set_attn_processor(lora_attn_procs)
+        # Saving as torch, properly reloads with directly filename
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            with self.assertRaises(IOError) as e:
+                new_model.load_attn_procs(tmpdirname, use_safetensors=True)
+            self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception))
+
+    def test_lora_on_off(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

-    # (`attn_processors`) needs to be implemented for this test in the model.
-    # def test_lora_save_safetensors_load_torch(self):
+        init_dict["attention_head_dim"] = 8

-    # (`attn_processors`) needs to be implemented for this test.
-    # def test_lora_save_torch_force_load_safetensors_error(self):
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
+
+        model.set_attn_processor(AttnProcessor())
+
+        with torch.no_grad():
+            new_sample = model(**inputs_dict).sample

-    # (`attn_processors`) needs to be added for this test.
-    # def test_lora_on_off(self):
+        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - old_sample).abs().max() < 1e-4

     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
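Usage note (not part of the patch): a minimal sketch of how the loader support this diff adds to UNet3DConditionModel might be exercised end to end, mirroring the create_lora_layers helper above. The checkpoint id is only illustrative, the LoRA layers stay untrained, and the snippet assumes the diffusers revision this diff targets.

import tempfile

from diffusers import UNet3DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor

# Illustrative checkpoint; any UNet3DConditionModel checkpoint would do.
unet = UNet3DConditionModel.from_pretrained("damo-vilab/text-to-video-ms-1.7b", subfolder="unet")

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # Only spatial cross-attention ("attn2") layers outside transformer_in and the
    # temporal attention blocks attend to the encoder hidden states.
    has_cross_attention = name.endswith("attn2.processor") and not (
        name.startswith("transformer_in") or "temp_attentions" in name.split(".")
    )
    cross_attention_dim = unet.config.cross_attention_dim if has_cross_attention else None
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    elif name.startswith("transformer_in"):
        # transformer_in projects to 8 * attention_head_dim channels (see the link above).
        hidden_size = 8 * unet.config.attention_head_dim
    lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)

# Round-trip the attention processors through the mixin methods enabled by this diff.
with tempfile.TemporaryDirectory() as tmpdirname:
    unet.save_attn_procs(tmpdirname)
    unet.load_attn_procs(tmpdirname)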