Skip to content

Commit 6101d4a

Browse files
committed
Add LoraLoaderMixin and update prepare_image_latents
1 parent baafe02 commit 6101d4a

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The HuggingFace Team. All rights reserved.
1+
# Copyright 2023 DiffEdit Authors and Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@
2323
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
2424

2525
from ...configuration_utils import FrozenDict
26-
from ...loaders import TextualInversionLoaderMixin
26+
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
2727
from ...models import AutoencoderKL, UNet2DConditionModel
2828
from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers
2929
from ...utils import (
@@ -230,13 +230,20 @@ def preprocess_mask(mask, batch_size: int = 1):
230230
return mask
231231

232232

233-
class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
233+
class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
234234
r"""
235235
Pipeline for text-guided image inpainting using Stable Diffusion using DiffEdit. *This is an experimental feature*.
236236
237237
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
238238
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
239239
240+
In addition the pipeline inherits the following loading methods:
241+
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
242+
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
243+
244+
as well as the following saving methods:
245+
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
246+
240247
Args:
241248
vae ([`AutoencoderKL`]):
242249
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -771,6 +778,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
771778
latents = latents * self.scheduler.init_noise_sigma
772779
return latents
773780

781+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.StableDiffusionPix2PixZeroPipeline.prepare_image_latents
774782
def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
775783
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
776784
raise ValueError(

0 commit comments

Comments
 (0)