Commit bdab2c7

Add draft for lora text encoder scale (huggingface#3626)
* Add draft for lora text encoder scale
* Improve naming
* fix: training dreambooth lora script.
* Apply suggestions from code review
* Update examples/dreambooth/train_dreambooth_lora.py
* Apply suggestions from code review
* Apply suggestions from code review
* add lora mixin when fit
* add lora mixin when fit
* add lora mixin when fit
* fix more
* fix more

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent 1fc4f76 commit bdab2c7

23 files changed: +331 -40 lines changed
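For context, a minimal usage sketch of what this change enables (not part of the diff): the `scale` entry of `cross_attention_kwargs` is now also forwarded to the text encoder's LoRA layers, not only the UNet attention processors. The model id and LoRA weights path below are placeholders, and the example assumes a pipeline that inherits `LoraLoaderMixin`.

import torch
from diffusers import StableDiffusionPipeline

# placeholder model id and LoRA weights path, purely illustrative
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora_weights")

# with this commit, `scale` also multiplies the text encoder LoRA output,
# via the monkey-patched forward: old_forward(x) + lora_scale * lora_layer(x)
image = pipe(
    "a photo of a sks dog",
    num_inference_steps=25,
    cross_attention_kwargs={"scale": 0.5},
).images[0]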

loaders.py

Lines changed: 11 additions & 1 deletion
@@ -852,6 +852,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
 
+        # set lora scale to a reasonable default
+        self._lora_scale = 1.0
+
         if use_safetensors and not is_safetensors_available():
             raise ValueError(
                 "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
@@ -953,6 +956,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
             warnings.warn(warn_message)
 
+    @property
+    def lora_scale(self) -> float:
+        # property function that returns the lora scale which can be set at run time by the pipeline.
+        # if _lora_scale has not been set, return 1
+        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+
     @property
     def text_encoder_lora_attn_procs(self):
         if hasattr(self, "_text_encoder_lora_attn_procs"):
@@ -1000,7 +1009,8 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
                 # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
                 def make_new_forward(old_forward, lora_layer):
                     def new_forward(x):
-                        return old_forward(x) + lora_layer(x)
+                        result = old_forward(x) + self.lora_scale * lora_layer(x)
+                        return result
 
                     return new_forward
 
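A self-contained sketch of the monkey-patching pattern above, with simplified stand-ins rather than the actual diffusers classes: a property supplies a run-time scale that defaults to 1.0, and the patched forward adds the scaled LoRA output on top of the original module output.

import torch
import torch.nn as nn


class TinyLoraLayer(nn.Module):
    """Stand-in for the LoRA down/up projection pair attached to a linear layer."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)

    def forward(self, x):
        return self.up(self.down(x))


class TinyPipeline:
    @property
    def lora_scale(self) -> float:
        # mirrors the new property: fall back to 1.0 when no run-time scale was set
        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0

    def patch(self, module: nn.Linear, lora_layer: TinyLoraLayer):
        old_forward = module.forward

        def new_forward(x):
            # run-time scale read through the property, as in the diff
            return old_forward(x) + self.lora_scale * lora_layer(x)

        # monkey-patch the module's forward with the closure
        module.forward = new_forward


pipe = TinyPipeline()
linear = nn.Linear(8, 8)
pipe.patch(linear, TinyLoraLayer(8, 8))

pipe._lora_scale = 0.5  # what _encode_prompt does when `lora_scale` is passed
out = linear(torch.randn(2, 8))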

pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 14 additions & 2 deletions
@@ -24,7 +24,7 @@
 
 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
@@ -52,7 +52,7 @@
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
-class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for text-to-image generation using Alt Diffusion.
 
@@ -291,6 +291,7 @@ def _encode_prompt(
         negative_prompt=None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -315,7 +316,14 @@ def _encode_prompt(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         """
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+            self._lora_scale = lora_scale
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -653,6 +661,9 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
@@ -661,6 +672,7 @@ def __call__(
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
         )
 
         # 4. Prepare timesteps
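For reference, the scale value flows through each patched pipeline as follows (names taken from this diff; the call arguments are illustrative):

# __call__(..., cross_attention_kwargs={"scale": 0.7})
#   -> text_encoder_lora_scale = cross_attention_kwargs.get("scale", None)
#   -> self._encode_prompt(..., lora_scale=text_encoder_lora_scale)
#        -> self._lora_scale = lora_scale          # stored on the pipeline
#             -> patched text encoder forward:
#                old_forward(x) + self.lora_scale * lora_layer(x)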

pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 14 additions & 2 deletions
@@ -26,7 +26,7 @@
 
 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
@@ -95,7 +95,7 @@ def preprocess(image):
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
-class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for text-guided image to image generation using Alt Diffusion.
 
@@ -302,6 +302,7 @@ def _encode_prompt(
         negative_prompt=None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -326,7 +327,14 @@ def _encode_prompt(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         """
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+            self._lora_scale = lora_scale
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -706,6 +714,9 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
@@ -714,6 +725,7 @@ def __call__(
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
         )
 
         # 4. Preprocess image

pipelines/controlnet/pipeline_controlnet.py

Lines changed: 14 additions & 2 deletions
@@ -25,7 +25,7 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -91,7 +91,7 @@
 """
 
 
-class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
 
@@ -291,6 +291,7 @@ def _encode_prompt(
         negative_prompt=None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -315,7 +316,14 @@ def _encode_prompt(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         """
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+            self._lora_scale = lora_scale
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -838,6 +846,9 @@ def __call__(
         guess_mode = guess_mode or global_pool_conditions
 
         # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
@@ -846,6 +857,7 @@ def __call__(
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
         )
 
         # 4. Prepare image

pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 14 additions & 2 deletions
@@ -25,7 +25,7 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -117,7 +117,7 @@ def prepare_image(image):
     return image
 
 
-class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
 
@@ -317,6 +317,7 @@ def _encode_prompt(
         negative_prompt=None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -341,7 +342,14 @@ def _encode_prompt(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         """
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+            self._lora_scale = lora_scale
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -929,6 +937,9 @@ def __call__(
         guess_mode = guess_mode or global_pool_conditions
 
        # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
@@ -937,6 +948,7 @@ def __call__(
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
         )
         # 4. Prepare image
         image = self.image_processor.preprocess(image).to(dtype=torch.float32)

pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 14 additions & 2 deletions
@@ -26,7 +26,7 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -223,7 +223,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
     return mask, masked_image
 
 
-class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
 
@@ -434,6 +434,7 @@ def _encode_prompt(
         negative_prompt=None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -458,7 +459,14 @@ def _encode_prompt(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         """
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+            self._lora_scale = lora_scale
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -1131,6 +1139,9 @@ def __call__(
         guess_mode = guess_mode or global_pool_conditions
 
         # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
@@ -1139,6 +1150,7 @@ def __call__(
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
         )
 
         # 4. Prepare image
