From 1e19e0fbd49def0426198186a304277e0051624c Mon Sep 17 00:00:00 2001 From: NormXU Date: Fri, 14 Apr 2023 18:55:59 +0800 Subject: [PATCH 1/3] add TextInversion for neg/prompt --- examples/community/lpw_stable_diffusion.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py index e912ad5244be..83d6bfcc0e3a 100644 --- a/examples/community/lpw_stable_diffusion.py +++ b/examples/community/lpw_stable_diffusion.py @@ -13,7 +13,7 @@ from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker from diffusers.utils import logging - +from diffusers.loaders import TextualInversionLoaderMixin try: from diffusers.utils import PIL_INTERPOLATION @@ -539,6 +539,11 @@ def _encode_prompt( " the batch size of `prompt`." ) + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer) + text_embeddings, uncond_embeddings = get_weighted_text_embeddings( pipe=self, prompt=prompt, From 70a95687f6f08f4fc7013fa08be81615c9c17db2 Mon Sep 17 00:00:00 2001 From: NormXU Date: Fri, 14 Apr 2023 19:12:16 +0800 Subject: [PATCH 2/3] fix code format --- examples/community/lpw_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py index 83d6bfcc0e3a..19a6492120a9 100644 --- a/examples/community/lpw_stable_diffusion.py +++ b/examples/community/lpw_stable_diffusion.py @@ -10,10 +10,10 @@ import diffusers from diffusers import SchedulerMixin, StableDiffusionPipeline +from diffusers.loaders import TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker from diffusers.utils import logging -from diffusers.loaders import TextualInversionLoaderMixin try: from diffusers.utils import PIL_INTERPOLATION From 9a26cac53cdfbb9bd46c7399927aad21a9883309 Mon Sep 17 00:00:00 2001 From: Nuo Xu Date: Tue, 25 Apr 2023 00:24:29 +0800 Subject: [PATCH 3/3] fix type & add text inversion for lpw_stable_diffusion_onnx.py --- examples/community/lpw_stable_diffusion.py | 2 +- examples/community/lpw_stable_diffusion_onnx.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py index 19a6492120a9..086cdf4844f7 100644 --- a/examples/community/lpw_stable_diffusion.py +++ b/examples/community/lpw_stable_diffusion.py @@ -539,7 +539,7 @@ def _encode_prompt( " the batch size of `prompt`." ) - # textual inversion: procecss multi-vector tokens if necessary + # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer) diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py index e756097cb7c3..e04a691d7cd3 100644 --- a/examples/community/lpw_stable_diffusion_onnx.py +++ b/examples/community/lpw_stable_diffusion_onnx.py @@ -10,6 +10,7 @@ import diffusers from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin +from diffusers.loaders import TextualInversionLoaderMixin from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import logging @@ -526,6 +527,11 @@ def _encode_prompt( " the batch size of `prompt`." ) + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer) + text_embeddings, uncond_embeddings = get_weighted_text_embeddings( pipe=self, prompt=prompt,