
Commit 9d7c08f

aandyw, patrickvonplaten, and sayakpaul authored
[WIP] implement rest of the test cases (LoRA tests) (#2824)
* initial commit for lora test cases
* help a bit with lora for 3d
* fixed lora tests
* replaced redundant code

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent dc27750 · commit 9d7c08f

File tree: 4 files changed, +206 −105 lines


src/diffusers/models/unet_3d_blocks.py

Lines changed: 9 additions & 3 deletions
@@ -251,7 +251,9 @@ def forward(
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
             ).sample
-            hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
+            hidden_states = temp_attn(
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+            ).sample
             hidden_states = resnet(hidden_states, temb)
             hidden_states = temp_conv(hidden_states, num_frames=num_frames)

@@ -376,7 +378,9 @@ def forward(
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
             ).sample
-            hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
+            hidden_states = temp_attn(
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+            ).sample

             output_states += (hidden_states,)

@@ -587,7 +591,9 @@ def forward(
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
             ).sample
-            hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
+            hidden_states = temp_attn(
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+            ).sample

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:

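Aside: a minimal toy sketch of why forwarding cross_attention_kwargs into temp_attn matters. The classes below (ToyLoRAProcessor, ToyTemporalAttention) are hypothetical and not the diffusers implementation; they only mirror the pattern by which the kwargs dict is eventually splatted into an attention processor call, which is how a LoRA processor's scale keyword (as in diffusers' LoRAAttnProcessor) can be reached.

import torch


class ToyLoRAProcessor:
    def __init__(self, hidden_size: int, rank: int = 4):
        self.down = torch.nn.Linear(hidden_size, rank, bias=False)
        self.up = torch.nn.Linear(rank, hidden_size, bias=False)

    def __call__(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # `scale` arrives via **cross_attention_kwargs from the block's forward pass.
        return hidden_states + scale * self.up(self.down(hidden_states))


class ToyTemporalAttention:
    def __init__(self, processor: ToyLoRAProcessor):
        self.processor = processor

    def __call__(self, hidden_states, num_frames=1, cross_attention_kwargs=None):
        cross_attention_kwargs = cross_attention_kwargs or {}
        # Without forwarding the dict (as the change above does), `scale` would never be applied.
        return self.processor(hidden_states, **cross_attention_kwargs)


attn = ToyTemporalAttention(ToyLoRAProcessor(hidden_size=8))
x = torch.randn(2, 8)
out = attn(x, num_frames=1, cross_attention_kwargs={"scale": 0.5})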
src/diffusers/models/unet_3d_condition.py

Lines changed: 5 additions & 2 deletions
@@ -20,6 +20,7 @@
 import torch.utils.checkpoint

 from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
@@ -50,7 +51,7 @@ class UNet3DConditionOutput(BaseOutput):
     sample: torch.FloatTensor


-class UNet3DConditionModel(ModelMixin, ConfigMixin):
+class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     r"""
     UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
     and returns sample shaped output.
@@ -465,7 +466,9 @@ def forward(
         sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
         sample = self.conv_in(sample)

-        sample = self.transformer_in(sample, num_frames=num_frames).sample
+        sample = self.transformer_in(
+            sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+        ).sample

         # 3. down
         down_block_res_samples = (sample,)

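For context, a hedged usage sketch of what the mixin plus the forwarded kwargs enable together. The tiny config below is assumed for illustration only (it follows the style of the test suite's init dicts and may not match it exactly), and keyword defaults can differ between diffusers versions.

import torch
from diffusers import UNet3DConditionModel

# Tiny illustrative config (assumed, not taken from this commit).
unet = UNet3DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
    up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)

sample = torch.randn(1, 4, 2, 32, 32)          # (batch, channels, num_frames, height, width)
encoder_hidden_states = torch.randn(1, 4, 32)  # (batch, sequence_length, cross_attention_dim)

with torch.no_grad():
    out = unet(sample, timestep=1, encoder_hidden_states=encoder_hidden_states).sample

# With LoRA attention processors attached (e.g. via the mixin's load_attn_procs), the
# cross_attention_kwargs forwarded by this commit would also reach transformer_in and the
# temporal attentions, e.g.:
#   unet(sample, timestep=1, encoder_hidden_states=encoder_hidden_states,
#        cross_attention_kwargs={"scale": 0.5})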
tests/models/test_models_unet_2d_condition.py

Lines changed: 12 additions & 79 deletions
@@ -41,7 +41,7 @@
 torch.backends.cuda.matmul.allow_tf32 = False


-def create_lora_layers(model):
+def create_lora_layers(model, mock_weights: bool = True):
     lora_attn_procs = {}
     for name in model.attn_processors.keys():
         cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
@@ -57,12 +57,13 @@ def create_lora_layers(model):
         lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         lora_attn_procs[name] = lora_attn_procs[name].to(model.device)

-        # add 1 to weights to mock trained weights
-        with torch.no_grad():
-            lora_attn_procs[name].to_q_lora.up.weight += 1
-            lora_attn_procs[name].to_k_lora.up.weight += 1
-            lora_attn_procs[name].to_v_lora.up.weight += 1
-            lora_attn_procs[name].to_out_lora.up.weight += 1
+        if mock_weights:
+            # add 1 to weights to mock trained weights
+            with torch.no_grad():
+                lora_attn_procs[name].to_q_lora.up.weight += 1
+                lora_attn_procs[name].to_k_lora.up.weight += 1
+                lora_attn_procs[name].to_v_lora.up.weight += 1
+                lora_attn_procs[name].to_out_lora.up.weight += 1

     return lora_attn_procs
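With the new mock_weights flag, the two kinds of call sites in the tests reduce to roughly the following (illustrative, assuming `model` is the UNet built by the test harness):

# Mocked ("trained") LoRA weights: used where a test needs the LoRA layers to change the output.
lora_attn_procs = create_lora_layers(model)
model.set_attn_processor(lora_attn_procs)

# Freshly initialized LoRA layers: diffusers' LoRA `up` projections are zero-initialized,
# so the output stays unchanged — enough for the pure save/load round-trip tests.
lora_attn_procs = create_lora_layers(model, mock_weights=False)
model.set_attn_processor(lora_attn_procs)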

@@ -378,26 +379,7 @@ def test_lora_processors(self):
         with torch.no_grad():
             sample1 = model(**inputs_dict).sample

-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-
-            # add 1 to weights to mock trained weights
-            with torch.no_grad():
-                lora_attn_procs[name].to_q_lora.up.weight += 1
-                lora_attn_procs[name].to_k_lora.up.weight += 1
-                lora_attn_procs[name].to_v_lora.up.weight += 1
-                lora_attn_procs[name].to_out_lora.up.weight += 1
+        lora_attn_procs = create_lora_layers(model)

         # make sure we can set a list of attention processors
         model.set_attn_processor(lora_attn_procs)
@@ -465,28 +447,7 @@ def test_lora_save_load_safetensors(self):
         with torch.no_grad():
             old_sample = model(**inputs_dict).sample

-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
-            # add 1 to weights to mock trained weights
-            with torch.no_grad():
-                lora_attn_procs[name].to_q_lora.up.weight += 1
-                lora_attn_procs[name].to_k_lora.up.weight += 1
-                lora_attn_procs[name].to_v_lora.up.weight += 1
-                lora_attn_procs[name].to_out_lora.up.weight += 1
-
+        lora_attn_procs = create_lora_layers(model)
         model.set_attn_processor(lora_attn_procs)

         with torch.no_grad():
@@ -518,21 +479,7 @@ def test_lora_save_safetensors_load_torch(self):
         model = self.model_class(**init_dict)
         model.to(torch_device)

-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)
         model.set_attn_processor(lora_attn_procs)
         # Saving as torch, properly reloads with directly filename
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -553,21 +500,7 @@ def test_lora_save_torch_force_load_safetensors_error(self):
         model = self.model_class(**init_dict)
         model.to(torch_device)

-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)
         model.set_attn_processor(lora_attn_procs)
         # Saving as torch, properly reloads with directly filename
         with tempfile.TemporaryDirectory() as tmpdirname:
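The save/load round-trip tests that now call the helper follow roughly this pattern. This is a hedged sketch of a test-method body on this class, reusing its fixtures (`init_dict`, `inputs_dict`, `torch_device`); the exact weight file names, keyword arguments, and tolerances in the real tests may differ.

import tempfile

import torch

torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
model.set_attn_processor(create_lora_layers(model))

with torch.no_grad():
    sample_before = model(**inputs_dict).sample

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_attn_procs(tmpdirname, safe_serialization=True)

    # Rebuild an identically seeded model and load the serialized LoRA processors back.
    torch.manual_seed(0)
    new_model = self.model_class(**init_dict)
    new_model.to(torch_device)
    new_model.load_attn_procs(tmpdirname)

with torch.no_grad():
    sample_after = new_model(**inputs_dict).sample

# The reloaded LoRA weights should reproduce the original output.
assert (sample_before - sample_after).abs().max() < 1e-4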
