diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py index e912ad5244be..086cdf4844f7 100644 --- a/examples/community/lpw_stable_diffusion.py +++ b/examples/community/lpw_stable_diffusion.py @@ -10,11 +10,11 @@ import diffusers from diffusers import SchedulerMixin, StableDiffusionPipeline +from diffusers.loaders import TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker from diffusers.utils import logging - try: from diffusers.utils import PIL_INTERPOLATION except ImportError: @@ -539,6 +539,11 @@ def _encode_prompt( " the batch size of `prompt`." ) + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer) + text_embeddings, uncond_embeddings = get_weighted_text_embeddings( pipe=self, prompt=prompt, diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py index e756097cb7c3..e04a691d7cd3 100644 --- a/examples/community/lpw_stable_diffusion_onnx.py +++ b/examples/community/lpw_stable_diffusion_onnx.py @@ -10,6 +10,7 @@ import diffusers from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin +from diffusers.loaders import TextualInversionLoaderMixin from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import logging @@ -526,6 +527,11 @@ def _encode_prompt( " the batch size of `prompt`." ) + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer) + text_embeddings, uncond_embeddings = get_weighted_text_embeddings( pipe=self, prompt=prompt,