Skip to content

Commit 983e032

Browse files
committed
remove changes from pipeline_stable_diffusion as part of imagic pipeline
1 parent 9c6bf05 commit 983e032

File tree

2 files changed

+5
-10
lines changed

2 files changed

+5
-10
lines changed

examples/community/imagic_stable_diffusion.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
modeled after the textual_inversion.py / train_dreambooth.py and the work
33
of justinpinkney here: https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb
44
"""
5+
import inspect
56
import warnings
67
from typing import List, Optional, Union
7-
import inspect
88

99
import numpy as np
1010
import torch
@@ -18,9 +18,10 @@
1818
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
1919
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
2020
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
21+
from diffusers.utils import logging
2122
from tqdm.auto import tqdm
2223
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
23-
from diffusers.utils import logging
24+
2425

2526
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
2627

@@ -119,7 +120,6 @@ def disable_attention_slicing(self):
119120
# set slice_size = `None` to disable `attention slicing`
120121
self.enable_attention_slicing(None)
121122

122-
123123
def train(
124124
self,
125125
prompt: Union[str, List[str]],
@@ -379,14 +379,13 @@ def __call__(
379379
"""
380380
if height % 8 != 0 or width % 8 != 0:
381381
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
382-
if self.text_embeddings is None:
382+
if self.text_embeddings is None:
383383
raise ValueError("Please run the pipe.train() before trying to generate an image.")
384384
if self.text_embeddings_orig is None:
385385
raise ValueError("Please run the pipe.train() before trying to generate an image.")
386386

387387
text_embeddings = alpha * self.text_embeddings_orig + (1 - alpha) * self.text_embeddings
388388

389-
390389
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
391390
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
392391
# corresponds to doing no classifier free guidance.
@@ -487,4 +486,4 @@ def __call__(
487486
if not return_dict:
488487
return (image, has_nsfw_concept)
489488

490-
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
489+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,6 @@ def __call__(
174174
return_dict: bool = True,
175175
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
176176
callback_steps: Optional[int] = 1,
177-
text_embeddings: Optional[torch.FloatTensor] = None,
178177
**kwargs,
179178
):
180179
r"""
@@ -266,9 +265,6 @@ def __call__(
266265
)
267266
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
268267

269-
if text_embeddings is None:
270-
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
271-
272268
# duplicate text embeddings for each generation per prompt, using mps friendly method
273269
bs_embed, seq_len, _ = text_embeddings.shape
274270
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)

0 commit comments

Comments
 (0)