Skip to content

Commit 2e4dc3e

Browse files
authored
[LoRA] add: test to check if peft loras are loadable in non-peft envs. (#6400)
* add: test to check if peft loras are loadable in non-peft envs. * add torch_device approrpiately. * fix: get_dummy_inputs(). * test logits. * rename * debug * debug * fix: generator * new assertion values after fixing the seed. * shape * remove print statements and settle this. * to update values. * change values when lora config is initialized under a fixed seed. * update colab link * update notebook link * sanity restored by getting the exact same values without peft.
1 parent 3e2961f commit 2e4dc3e

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# coding=utf-8
2+
# Copyright 2023 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
18+
import numpy as np
19+
import torch
20+
21+
from diffusers import DiffusionPipeline
22+
from diffusers.utils.testing_utils import torch_device
23+
24+
25+
class PEFTLoRALoading(unittest.TestCase):
26+
def get_dummy_inputs(self):
27+
pipeline_inputs = {
28+
"prompt": "A painting of a squirrel eating a burger",
29+
"num_inference_steps": 2,
30+
"guidance_scale": 6.0,
31+
"output_type": "np",
32+
"generator": torch.manual_seed(0),
33+
}
34+
return pipeline_inputs
35+
36+
def test_stable_diffusion_peft_lora_loading_in_non_peft(self):
37+
sd_pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
38+
# This LoRA was obtained using similarly as how it's done in the training scripts.
39+
# For details on how the LoRA was obtained, refer to:
40+
# https://hf.co/datasets/diffusers/notebooks/blob/main/check_logits_with_serialization_peft_lora.py
41+
sd_pipe.load_lora_weights("hf-internal-testing/tiny-sd-lora-peft")
42+
43+
inputs = self.get_dummy_inputs()
44+
outputs = sd_pipe(**inputs).images
45+
46+
predicted_slice = outputs[0, -3:, -3:, -1].flatten()
47+
expected_slice = np.array([0.5396, 0.5707, 0.477, 0.4665, 0.5419, 0.4594, 0.4857, 0.4741, 0.4804])
48+
49+
self.assertTrue(outputs.shape == (1, 64, 64, 3))
50+
assert np.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3)
51+
52+
def test_stable_diffusion_xl_peft_lora_loading_in_non_peft(self):
53+
sd_pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-pipe").to(torch_device)
54+
# This LoRA was obtained using similarly as how it's done in the training scripts.
55+
sd_pipe.load_lora_weights("hf-internal-testing/tiny-sdxl-lora-peft")
56+
57+
inputs = self.get_dummy_inputs()
58+
outputs = sd_pipe(**inputs).images
59+
60+
predicted_slice = outputs[0, -3:, -3:, -1].flatten()
61+
expected_slice = np.array([0.613, 0.5566, 0.54, 0.4162, 0.4042, 0.4596, 0.5374, 0.5286, 0.5038])
62+
63+
self.assertTrue(outputs.shape == (1, 64, 64, 3))
64+
assert np.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3)

0 commit comments

Comments
 (0)