Skip to content

Commit 80ac0d4

Browse files
committed
Attempt 1 to fix tests.
1 parent c4210ae commit 80ac0d4

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

tests/lora/test_lora_layers_af.py

Lines changed: 32 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,13 +15,19 @@
1515
import sys
1616
import unittest
1717

18+
import torch
1819
from transformers import AutoTokenizer, T5EncoderModel
1920

2021
from diffusers import (
2122
AuraFlowPipeline,
23+
AuraFlowTransformer2DModel,
2224
FlowMatchEulerDiscreteScheduler,
2325
)
24-
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend
26+
from diffusers.utils.testing_utils import (
27+
floats_tensor,
28+
is_peft_available,
29+
require_peft_backend,
30+
)
2531

2632

2733
if is_peft_available():
@@ -49,8 +55,9 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
4955
"joint_attention_dim": 32,
5056
"caption_projection_dim": 32,
5157
"out_channels": 4,
52-
"pos_embed_max_size": 32,
58+
"pos_embed_max_size": 64,
5359
}
60+
transformer_cls = AuraFlowTransformer2DModel
5461
tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
5562
text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
5663

@@ -71,3 +78,26 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
7178
@property
7279
def output_shape(self):
7380
return (1, 64, 64, 3)
81+
82+
def get_dummy_inputs(self, with_generator=True):
    """Build the dummy test fixtures for the AuraFlow LoRA tests.

    Returns a ``(noise, input_ids, pipeline_inputs)`` triple:
    latent-like noise, random token ids, and the keyword arguments
    for a short pipeline call.

    Args:
        with_generator: when True, the seeded torch generator is also
            placed into ``pipeline_inputs`` under the "generator" key.
    """
    batch_size, sequence_length = 1, 10
    # NOTE(review): noise is (batch, channels, H, W); presumably a
    # latent-space tensor — pipeline height/width below are 8.
    noise_shape = (batch_size, 4) + (32, 32)

    # Seed the global RNG up front so every tensor built here is
    # reproducible across runs.
    generator = torch.manual_seed(0)
    noise = floats_tensor(noise_shape)
    input_ids = torch.randint(
        1, sequence_length, size=(batch_size, sequence_length), generator=generator
    )

    pipeline_inputs = {
        "prompt": "A painting of a squirrel eating a burger",
        "num_inference_steps": 4,
        "guidance_scale": 0.0,
        "height": 8,
        "width": 8,
        "output_type": "np",
    }
    if with_generator:
        pipeline_inputs["generator"] = generator

    return noise, input_ids, pipeline_inputs

0 commit comments

Comments
 (0)