@@ -22,7 +22,14 @@
 
 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
-from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.models.attention_processor import (
+    Attention,
+    AttnProcessor,
+    AttnProcessor2_0,
+    LoRAAttnProcessor,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
 from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device
 
 
@@ -212,3 +219,90 @@ def test_lora_save_load_legacy(self):
 
         # Outputs shouldn't match.
         self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
+
+    def create_lora_weight_file(self, tmpdirname):
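+        # Helper: serialize the dummy UNet and text-encoder LoRA layers so the tests below can load them back.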
+        _, lora_components = self.get_dummy_components()
+        LoraLoaderMixin.save_lora_weights(
+            save_directory=tmpdirname,
+            unet_lora_layers=lora_components["unet_lora_layers"],
+            text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
+        )
+        self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+
+    def test_lora_unet_attn_processors(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            self.create_lora_weight_file(tmpdirname)
+
+            pipeline_components, _ = self.get_dummy_components()
+            sd_pipe = StableDiffusionPipeline(**pipeline_components)
+            sd_pipe = sd_pipe.to(torch_device)
+            sd_pipe.set_progress_bar_config(disable=None)
+
+            # check if vanilla attention processors are used
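+            # (diffusers defaults to AttnProcessor2_0 on PyTorch >= 2.0, which ships scaled dot-product attention)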
+            for _, module in sd_pipe.unet.named_modules():
+                if isinstance(module, Attention):
+                    self.assertIsInstance(module.processor, (AttnProcessor, AttnProcessor2_0))
+
+            # load LoRA weight file
+            sd_pipe.load_lora_weights(tmpdirname)
+
+            # check if lora attention processors are used
+            for _, module in sd_pipe.unet.named_modules():
+                if isinstance(module, Attention):
+                    self.assertIsInstance(module.processor, LoRAAttnProcessor)
+
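+    # xformers memory-efficient attention is only available on CUDA, hence the GPU-only skips on the two tests below.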
+    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+    def test_lora_unet_attn_processors_with_xformers(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            self.create_lora_weight_file(tmpdirname)
+
+            pipeline_components, _ = self.get_dummy_components()
+            sd_pipe = StableDiffusionPipeline(**pipeline_components)
+            sd_pipe = sd_pipe.to(torch_device)
+            sd_pipe.set_progress_bar_config(disable=None)
+
+            # enable XFormers
+            sd_pipe.enable_xformers_memory_efficient_attention()
+
+            # check if xFormers attention processors are used
+            for _, module in sd_pipe.unet.named_modules():
+                if isinstance(module, Attention):
+                    self.assertIsInstance(module.processor, XFormersAttnProcessor)
+
+            # load LoRA weight file
+            sd_pipe.load_lora_weights(tmpdirname)
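+            # loading LoRA weights while xformers is active should swap in the xformers-aware LoRA processors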
+
+            # check if lora attention processors are used
+            for _, module in sd_pipe.unet.named_modules():
+                if isinstance(module, Attention):
+                    self.assertIsInstance(module.processor, LoRAXFormersAttnProcessor)
+
+    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+    def test_lora_save_load_with_xformers(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPipeline(**pipeline_components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        noise, input_ids, pipeline_inputs = self.get_dummy_inputs()
+
+        # enable XFormers
+        sd_pipe.enable_xformers_memory_efficient_attention()
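+        # the baseline below also runs under xformers, so any output change is attributable to the LoRA weights alone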
+
+        original_images = sd_pipe(**pipeline_inputs).images
+        orig_image_slice = original_images[0, -3:, -3:, -1]
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            LoraLoaderMixin.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            sd_pipe.load_lora_weights(tmpdirname)
+
+            lora_images = sd_pipe(**pipeline_inputs).images
+            lora_image_slice = lora_images[0, -3:, -3:, -1]
+
+            # Outputs shouldn't match.
+            self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))